diff --git a/.azure-pipelines/azure-pipelines.yml b/.azure-pipelines/azure-pipelines.yml deleted file mode 100644 index 4115f910..00000000 --- a/.azure-pipelines/azure-pipelines.yml +++ /dev/null @@ -1,131 +0,0 @@ -trigger: - branches: - include: - - '*' - tags: - include: - - '*' - -stages: -- stage: static - displayName: Static Analysis - jobs: - - job: checks - displayName: static code analysis - pool: - vmImage: ubuntu-20.04 - steps: - # Use Python >=3.8 for syntax validation - - task: UsePythonVersion@0 - displayName: Set up python - inputs: - versionSpec: 3.8 - - # Run syntax validation on a shallow clone - - bash: | - python .azure-pipelines/syntax-validation.py - displayName: Syntax validation - - # Run flake8 validation on a shallow clone - - bash: | - pip install --disable-pip-version-check flake8 - python .azure-pipelines/flake8-validation.py - displayName: Flake8 validation - -- stage: build - displayName: Build - dependsOn: - jobs: - - job: build - displayName: build package - pool: - vmImage: ubuntu-20.04 - steps: - - task: UsePythonVersion@0 - displayName: Set up python - inputs: - versionSpec: 3.9 - - - bash: | - pip install --disable-pip-version-check collective.checkdocs wheel - displayName: Install dependencies - - - bash: | - set -ex - python setup.py sdist bdist_wheel - mkdir -p dist/pypi - shopt -s extglob - mv -v dist/!(pypi) dist/pypi - git archive HEAD | gzip > dist/repo-source.tar.gz - ls -laR dist - displayName: Build python package - - - task: PublishBuildArtifacts@1 - displayName: Store artifact - inputs: - pathToPublish: dist/ - artifactName: package - - - bash: python setup.py checkdocs - displayName: Check package description - -- stage: tests - displayName: Run unit tests - dependsOn: - - static - - build - jobs: - - job: linux - pool: - vmImage: ubuntu-20.04 - strategy: - matrix: - python38: - PYTHON_VERSION: 3.8 - python39: - PYTHON_VERSION: 3.9 - python310: - PYTHON_VERSION: 3.10 - python311: - PYTHON_VERSION: 3.11 - steps: - - template: ci.yml - -- stage: deploy - displayName: Publish release - dependsOn: - - tests - condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) - jobs: - - job: pypi - displayName: Publish pypi release - pool: - vmImage: ubuntu-20.04 - steps: - - checkout: none - - - task: UsePythonVersion@0 - displayName: Set up python - inputs: - versionSpec: 3.11 - - - task: DownloadBuildArtifacts@0 - displayName: Get pre-built package - inputs: - buildType: 'current' - downloadType: 'single' - artifactName: 'package' - downloadPath: '$(System.ArtifactsDirectory)' - - - script: | - pip install --disable-pip-version-check twine - displayName: Install twine - - - task: TwineAuthenticate@1 - displayName: Set up credentials - inputs: - pythonUploadServiceConnection: pypi-workflows - - - bash: | - python -m twine upload -r pypi-workflows --config-file $(PYPIRC_PATH) $(System.ArtifactsDirectory)/package/pypi/*.tar.gz $(System.ArtifactsDirectory)/package/pypi/*.whl - displayName: Publish package diff --git a/.azure-pipelines/ci.yml b/.azure-pipelines/ci.yml deleted file mode 100644 index a994af57..00000000 --- a/.azure-pipelines/ci.yml +++ /dev/null @@ -1,70 +0,0 @@ -steps: -- checkout: none - -- bash: | - set -eux - mkdir rabbitmq-docker && cd rabbitmq-docker - - cat <rabbitmq.conf - # allowing remote connections for default user is highly discouraged - # as it dramatically decreases the security of the system. Delete the user - # instead and create a new one with generated secure credentials. - loopback_users = none - EOF - - cat <Dockerfile - FROM rabbitmq:3.9-management - COPY rabbitmq.conf /etc/rabbitmq/rabbitmq.conf - EOF - - docker build -t azure-rabbitmq . - docker run --detach --name rabbitmq -p 127.0.0.1:5672:5672 -p 127.0.0.1:15672:15672 azure-rabbitmq - docker container list -a - displayName: Start RabbitMQ container - workingDirectory: $(Pipeline.Workspace) - -- task: UsePythonVersion@0 - inputs: - versionSpec: '$(PYTHON_VERSION)' - displayName: 'Use Python $(PYTHON_VERSION)' - -- task: DownloadBuildArtifacts@0 - displayName: Get pre-built package - inputs: - buildType: 'current' - downloadType: 'single' - artifactName: 'package' - downloadPath: '$(System.ArtifactsDirectory)' - -- task: ExtractFiles@1 - displayName: Checkout sources - inputs: - archiveFilePatterns: "$(System.ArtifactsDirectory)/package/repo-source.tar.gz" - destinationFolder: "$(Pipeline.Workspace)/src" - -- script: | - set -eux - pip install --disable-pip-version-check -r "$(Pipeline.Workspace)/src/requirements_dev.txt" - pip install --no-deps --disable-pip-version-check -e "$(Pipeline.Workspace)/src" - displayName: Install package - -- script: | - wget -t 10 -w 1 http://127.0.0.1:15672 -O - - displayName: Check RabbitMQ is alive - -- script: | - PYTHONDEVMODE=1 pytest -v -ra --cov=workflows --cov-report=xml --cov-branch - displayName: Run tests - workingDirectory: $(Pipeline.Workspace)/src - -- bash: bash <(curl -s https://codecov.io/bash) -t $(CODECOV_TOKEN) -n "Python $(PYTHON_VERSION) $(Agent.OS)" - displayName: Publish coverage stats - continueOnError: True - workingDirectory: $(Pipeline.Workspace)/src - timeoutInMinutes: 2 - -- script: | - docker logs rabbitmq - docker stop rabbitmq - displayName: Show RabbitMQ logs - condition: succeededOrFailed() diff --git a/.azure-pipelines/flake8-validation.py b/.azure-pipelines/flake8-validation.py deleted file mode 100644 index 7b997fab..00000000 --- a/.azure-pipelines/flake8-validation.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations - -import os -import subprocess - -# Flake8 validation -failures = 0 -try: - flake8 = subprocess.run( - [ - "flake8", - "--exit-zero", - ], - capture_output=True, - check=True, - encoding="latin-1", - timeout=300, - ) -except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: - print( - "##vso[task.logissue type=error;]flake8 validation failed with", - str(e.__class__.__name__), - ) - print(e.stdout) - print(e.stderr) - print("##vso[task.complete result=Failed;]flake8 validation failed") - exit() -for line in flake8.stdout.split("\n"): - if ":" not in line: - continue - filename, lineno, column, error = line.split(":", maxsplit=3) - errcode, error = error.strip().split(" ", maxsplit=1) - filename = os.path.normpath(filename) - failures += 1 - print( - f"##vso[task.logissue type=error;sourcepath={filename};" - f"linenumber={lineno};columnnumber={column};code={errcode};]" + error - ) - -if failures: - print(f"##vso[task.logissue type=warning]Found {failures} flake8 violation(s)") - print(f"##vso[task.complete result=Failed;]Found {failures} flake8 violation(s)") diff --git a/.azure-pipelines/syntax-validation.py b/.azure-pipelines/syntax-validation.py deleted file mode 100644 index 2d74948a..00000000 --- a/.azure-pipelines/syntax-validation.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import ast -import os -import sys - -print("Python", sys.version, "\n") - -failures = 0 - -for base, _, files in os.walk("."): - for f in files: - if not f.endswith(".py"): - continue - filename = os.path.normpath(os.path.join(base, f)) - try: - with open(filename) as fh: - ast.parse(fh.read()) - except SyntaxError as se: - failures += 1 - print( - f"##vso[task.logissue type=error;sourcepath={filename};" - f"linenumber={se.lineno};columnnumber={se.offset};]" - f"SyntaxError: {se.msg}" - ) - print(" " + se.text + " " * se.offset + "^") - print(f"SyntaxError: {se.msg} in {filename} line {se.lineno}") - print() - -if failures: - print(f"##vso[task.logissue type=warning]Found {failures} syntax error(s)") - print(f"##vso[task.complete result=Failed;]Found {failures} syntax error(s)") diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 5c23fc2d..00000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,16 +0,0 @@ -[bumpversion] -current_version = 2.26 -parse = (?P\d+)\.(?P\d+)(\.(?P\d+))? -serialize = - {major}.{minor}.{patch} - {major}.{minor} -commit = True -tag = True - -[bumpversion:file:setup.cfg] -search = version = {current_version} -replace = version = {new_version} - -[bumpversion:file:src/workflows/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 7bb224e8..29ec2510 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -24,18 +24,18 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} queries: +security-and-quality - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 with: category: "/language:${{ matrix.language }}" diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 00000000..8484c06b --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,65 @@ +name: Build and Test + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build distribution + run: python -m build + - uses: actions/upload-artifact@v4 + with: + path: ./dist/* + + test: + needs: build + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: artifact + path: dist + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements_dev.txt dist/*.whl + - name: Test with pytest + run: | + pytest tests + + pypi-publish: + name: Upload release to PyPI + needs: test + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + runs-on: ubuntu-latest + environment: + name: release + url: https://pypi.org/workflows + permissions: + id-token: write + steps: + - uses: actions/download-artifact@v4 + with: + name: artifact + path: dist + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0 diff --git a/.mypy.ini b/.mypy.ini deleted file mode 100644 index 3c756288..00000000 --- a/.mypy.ini +++ /dev/null @@ -1,2 +0,0 @@ -[mypy] -mypy_path=src diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f02679e..7e73c57a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,11 @@ +ci: + skip: [mypy] + autoupdate_schedule: quarterly + repos: # Syntax validation and some basic sanity checks - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v5.0.0 hooks: - id: check-merge-conflict - id: check-ast @@ -11,40 +15,23 @@ repos: args: ['--maxkb=200'] - id: check-yaml -# Automatically sort imports -- repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort - args: [ - '-a', 'from __future__ import annotations', # 3.7-3.11 - '--rm', 'from __future__ import absolute_import', # -3.0 - '--rm', 'from __future__ import division', # -3.0 - '--rm', 'from __future__ import generator_stop', # -3.7 - '--rm', 'from __future__ import generators', # -2.3 - '--rm', 'from __future__ import nested_scopes', # -2.2 - '--rm', 'from __future__ import print_function', # -3.0 - '--rm', 'from __future__ import unicode_literals', # -3.0 - '--rm', 'from __future__ import with_statement', # -2.6 - ] - -# Automatic source code formatting -- repo: https://github.com/psf/black - rev: 22.3.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.8 hooks: - - id: black - args: [--safe, --quiet] + - id: ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] + - id: ruff-format -# Linting -- repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - additional_dependencies: ['flake8-comprehensions==3.8.0'] +# # Linting +# - repo: https://github.com/PyCQA/flake8 +# rev: 7.1.1 +# hooks: +# - id: flake8 +# additional_dependencies: ['flake8-comprehensions==3.8.0'] # Type checking - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v1.13.0 hooks: - id: mypy files: 'src/.*\.py$' diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 1b9ee9dc..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,13 +0,0 @@ -include AUTHORS.rst -include CONTRIBUTING.rst -include HISTORY.rst -include LICENSE -include README.MD - -recursive-include tests * -recursive-exclude * __pycache__ -recursive-exclude * *.py[co] - -recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif - -include src/workflows/py.typed \ No newline at end of file diff --git a/catalog-info.yaml b/catalog-info.yaml new file mode 100644 index 00000000..45435797 --- /dev/null +++ b/catalog-info.yaml @@ -0,0 +1,19 @@ +apiVersion: backstage.io/v1alpha1 +kind: Component +metadata: + name: workflows + title: Workflows + description: > + Workflows enables light-weight services to process tasks in a + message-oriented environment. + annotations: + github.com/project-slug: DiamondLightSource/python-workflows + diamond.ac.uk/viewdocs-url: https://zocalo.readthedocs.io/en/latest/workflows.html + tags: + - python +spec: + type: library + lifecycle: production + owner: group:data-analysis + dependsOn: + - resource:zocalo-rabbitmq diff --git a/pyproject.toml b/pyproject.toml index 74058fc2..d9b1cc7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,90 @@ [build-system] -requires = ["setuptools >= 40.6.0", "wheel"] +requires = ["setuptools>=61.2", "setuptools-scm"] build-backend = "setuptools.build_meta" +[project] +name = "workflows" +version = "2.28" +description = "Data processing in distributed environments" +readme = "README.rst" +authors = [ + { name = "Diamond Light Source", email = "scientificsoftware@diamond.ac.uk" }, +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + "Topic :: Software Development :: Libraries :: Python Modules", +] +license = { text = "BSD-3-Clause" } +requires-python = ">=3.10" +dependencies = ["bidict", "pika", "setuptools", "stomp-py>=7"] + +[project.urls] +Download = "https://github.com/DiamondLightSource/python-workflows/releases" +Documentation = "https://github.com/DiamondLightSource/python-workflows" +GitHub = "https://github.com/DiamondLightSource/python-workflows" +Bug-Tracker = "https://github.com/DiamondLightSource/python-workflows/issues" + +[project.optional-dependencies] +prometheus = ["prometheus-client"] + +[project.entry-points."libtbx.dispatcher.script"] +"workflows.validate_recipe" = "workflows.validate_recipe" + +[project.entry-points."libtbx.precommit"] +workflows = "workflows" + +[project.entry-points."workflows.services"] +SampleConsumer = "workflows.services.sample_consumer:SampleConsumer" +SampleProducer = "workflows.services.sample_producer:SampleProducer" +SampleTxn = "workflows.services.sample_transaction:SampleTxn" +SampleTxnProducer = "workflows.services.sample_transaction:SampleTxnProducer" + +[project.entry-points."workflows.transport"] +PikaTransport = "workflows.transport.pika_transport:PikaTransport" +StompTransport = "workflows.transport.stomp_transport:StompTransport" +OfflineTransport = "workflows.transport.offline_transport:OfflineTransport" + +[project.entry-points."zocalo.configuration.plugins"] +pika = "workflows.util.zocalo.configuration:Pika" +stomp = "workflows.util.zocalo.configuration:Stomp" +transport = "workflows.util.zocalo.configuration:DefaultTransport" + +[project.scripts] +"workflows.validate_recipe" = "workflows.recipe.validate:main" + [tool.isort] profile = "black" [tool.pytest.ini_options] addopts = "-ra" required_plugins = "pytest-timeout" + +[tool.bumpversion] +current_version = "2.28" +parse = '(?P\d+)\.(?P\d+)' +serialize = ["{major}.{minor}"] +commit = true +tag = true + +[[tool.bumpversion.files]] +filename = "pyproject.toml" + +[[tool.bumpversion.files]] +filename = "src/workflows/__init__.py" + +[tool.ruff.lint] +select = ["E", "F", "W", "C4", "I"] +unfixable = ["F841"] +# E501 line too long (handled by formatter) +ignore = ["E501"] + +[tool.ruff.lint.isort] +known-first-party = ["dxtbx_*", "dxtbx"] +required-imports = ["from __future__ import annotations"] + +[tool.mypy] +mypy_path = "src/" diff --git a/requirements_dev.txt b/requirements_dev.txt index 54129b3f..8207c45b 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,9 +1,9 @@ -bidict==0.22.1 -prometheus-client==0.15.0 -pytest==7.2.0 -pytest-cov==4.0.0 -pytest-mock==3.10.0 -pytest-timeout==2.1.0 -setuptools==65.6.3 -stomp.py==8.1.0 -pika==1.3.1 +bidict==0.23.1 +pika==1.3.2 +prometheus_client==0.21.0 +pytest==8.3.3 +pytest-cov==6.0.0 +pytest-mock==3.14.0 +pytest-timeout==2.3.1 +stomp-py==8.1.2 +websocket-client==1.8.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1dd49e1b..00000000 --- a/setup.cfg +++ /dev/null @@ -1,82 +0,0 @@ -[metadata] -name = workflows -version = 2.26 -description = Data processing in distributed environments -long_description = file: README.rst -author = Diamond Light Source - Scientific Software et al. -author_email = scientificsoftware@diamond.ac.uk -license = BSD -license_file = LICENSE -classifiers = - Development Status :: 5 - Production/Stable - Intended Audience :: Developers - License :: OSI Approved :: BSD License - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Operating System :: OS Independent - Topic :: Software Development :: Libraries :: Python Modules -project_urls = - Download = https://github.com/DiamondLightSource/python-workflows/releases - Documentation = https://github.com/DiamondLightSource/python-workflows - GitHub = https://github.com/DiamondLightSource/python-workflows - Bug-Tracker = https://github.com/DiamondLightSource/python-workflows/issues - -[options] -install_requires = - bidict - pika - setuptools - stomp.py>=7 -packages = find: -package_dir = - =src -python_requires = >=3.8 -zip_safe = False -include_package_data = True - -[options.extras_require] -prometheus = prometheus-client - -[options.entry_points] -console_scripts = - workflows.validate_recipe = workflows.recipe.validate:main -libtbx.dispatcher.script = - workflows.validate_recipe = workflows.validate_recipe -libtbx.precommit = - workflows = workflows -workflows.services = - SampleConsumer = workflows.services.sample_consumer:SampleConsumer - SampleProducer = workflows.services.sample_producer:SampleProducer - SampleTxn = workflows.services.sample_transaction:SampleTxn - SampleTxnProducer = workflows.services.sample_transaction:SampleTxnProducer -workflows.transport = - PikaTransport = workflows.transport.pika_transport:PikaTransport - StompTransport = workflows.transport.stomp_transport:StompTransport - OfflineTransport = workflows.transport.offline_transport:OfflineTransport -zocalo.configuration.plugins = - pika = workflows.util.zocalo.configuration:Pika - stomp = workflows.util.zocalo.configuration:Stomp - transport = workflows.util.zocalo.configuration:DefaultTransport - -[options.packages.find] -where = src - -[flake8] -# Black disagrees with flake8 on a few points. Ignore those. -ignore = E203, E266, E501, W503 -# E203 whitespace before ':' -# E266 too many leading '#' for block comment -# E501 line too long -# W503 line break before binary operator - -max-line-length = 88 - -select = - E401,E711,E712,E713,E714,E721,E722,E901, - F401,F402,F403,F405,F541,F631,F632,F633,F811,F812,F821,F822,F841,F901, - W191,W291,W292,W293,W602,W603,W604,W605,W606, - # flake8-comprehensions, https://github.com/adamchainz/flake8-comprehensions - C4, diff --git a/setup.py b/setup.py deleted file mode 100644 index f2825fc7..00000000 --- a/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env python - -from __future__ import annotations - -import setuptools - -if __name__ == "__main__": - # Do not add any parameters here. Edit setup.cfg instead. - setuptools.setup() diff --git a/src/workflows/__init__.py b/src/workflows/__init__.py index 8d8563d4..6dfead90 100644 --- a/src/workflows/__init__.py +++ b/src/workflows/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__version__ = "2.26" +__version__ = "2.28" def version() -> str: diff --git a/src/workflows/contrib/start_service.py b/src/workflows/contrib/start_service.py index 90179134..bbcfab48 100644 --- a/src/workflows/contrib/start_service.py +++ b/src/workflows/contrib/start_service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from optparse import SUPPRESS_HELP, OptionParser import workflows @@ -62,6 +63,7 @@ def run( # Enumerate all known services known_services = sorted(workflows.services.get_known_services()) + known_services_help = "Known services: " + ", ".join(known_services) if version: version = f"{version} (workflows {workflows.version()})" @@ -76,15 +78,11 @@ def run( parser.add_option( "-s", "--service", - dest="service", metavar="SVC", - default=None, - help="Name of the service to start. Known services: " - + ", ".join(known_services), + help=f"Name of the service to start. {known_services_help}", ) parser.add_option( "--liveness", - dest="liveness", action="store_true", default=False, help=( @@ -95,14 +93,12 @@ def run( ) parser.add_option( "--liveness-port", - dest="liveness_port", default=8000, type="int", help="Expose liveness check endpoint on this port.", ) parser.add_option( "--liveness-timeout", - dest="liveness_timeout", default=30, type="float", help="Timeout for the liveness check (in seconds).", @@ -111,17 +107,15 @@ def run( parser.add_option( "-m", "--metrics", - dest="metrics", action="store_true", default=False, help=( - "Record metrics for this service and expose them on the port defined by" - "the --metrics-port option." + "Record metrics for this service and expose them on the port " + "defined by the --metrics-port option." ), ) parser.add_option( "--metrics-port", - dest="metrics_port", default=8080, type="int", help="Expose metrics via a prometheus endpoint on this port.", @@ -137,6 +131,10 @@ def run( # Call on_parsing hook (options, args) = self.on_parsing(options, args) or (options, args) + # Exit with error if no service has been specified. + if not options.service: + parser.error(f"Please specify a service name. {known_services_help}") + # Create Transport factory transport_factory = workflows.transport.lookup(options.transport) @@ -155,17 +153,34 @@ def on_transport_preparation_hook(): transport_factory = on_transport_preparation_hook - # When service name is specified, check if service exists or can be derived - if options.service and options.service not in known_services: - matching = [s for s in known_services if s.startswith(options.service)] - if not matching: - matching = [ - s - for s in known_services - if s.lower().startswith(options.service.lower()) - ] - if matching and len(matching) == 1: - options.service = matching[0] + # When service name is specified, check if service exists or can be derived. + if options.service not in known_services: + # First check whether the provided service name is a case-insensitive match. + service_lower = options.service.lower() + match = {s.lower(): s for s in known_services}.get(service_lower, None) + match = ( + [match] + if match + # Next, check whether the provided service name is a partial + # case-sensitive match. + else [s for s in known_services if s.startswith(options.service)] + # Next check whether the provided service name is a partial + # case-insensitive match. + or [s for s in known_services if s.lower().startswith(service_lower)] + ) + + # Catch ambiguous partial matches and exit with an error. + if len(match) > 1: + sys.exit( + f"Specified service name {options.service} is ambiguous, partially " + f"matching each of these known services: " + ", ".join(match) + ) + # Otherwise, set the derived service name, if there's a unique match. + elif match: + (options.service,) = match + # Otherwise, exit with an error. + else: + sys.exit(f"Please specify a valid service name. {known_services_help}") kwargs.update( { diff --git a/src/workflows/contrib/status_monitor.py b/src/workflows/contrib/status_monitor.py index d11dfec6..cc109e3d 100644 --- a/src/workflows/contrib/status_monitor.py +++ b/src/workflows/contrib/status_monitor.py @@ -3,7 +3,7 @@ import curses import threading import time -from typing import Any, Dict +from typing import Any import workflows.transport from workflows.services.common_service import CommonService @@ -19,7 +19,7 @@ class Monitor: # pragma: no cover shutdown = False """Set to true to end the main loop and shut down the service monitor.""" - cards: Dict[Any, Any] = {} + cards: dict[Any, Any] = {} """Register card shown for seen services""" border_chars = () diff --git a/src/workflows/frontend/__init__.py b/src/workflows/frontend/__init__.py index ac560bd7..fe7686c0 100644 --- a/src/workflows/frontend/__init__.py +++ b/src/workflows/frontend/__init__.py @@ -368,7 +368,8 @@ def get_status(self): def exponential_backoff(self): """A function that keeps waiting longer and longer the more rapidly it is called. - It can be used to increasingly slow down service starts when they keep failing.""" + It can be used to increasingly slow down service starts when they keep failing. + """ last_service_switch = self._service_starttime if not last_service_switch: return diff --git a/src/workflows/recipe/__init__.py b/src/workflows/recipe/__init__.py index c1c355a9..e4104e59 100644 --- a/src/workflows/recipe/__init__.py +++ b/src/workflows/recipe/__init__.py @@ -2,7 +2,8 @@ import logging import functools -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from workflows.recipe.recipe import Recipe from workflows.recipe.validate import validate_recipe diff --git a/src/workflows/recipe/recipe.py b/src/workflows/recipe/recipe.py index d4744489..e157fc3f 100644 --- a/src/workflows/recipe/recipe.py +++ b/src/workflows/recipe/recipe.py @@ -3,7 +3,7 @@ import copy import json import string -from typing import Any, Dict +from typing import Any import workflows @@ -15,7 +15,7 @@ class Recipe: A recipe describes how all involved services are connected together, how data should be passed and how errors should be handled.""" - recipe: Dict[Any, Any] = {} + recipe: dict[Any, Any] = {} """The processing recipe is encoded in this dictionary.""" # TODO: Describe format @@ -334,7 +334,7 @@ def translate(x): new_recipe[idx]["error"] = translate(new_recipe[idx]["error"]) # Join 'start' nodes - for (idx, param) in other.recipe["start"]: + for idx, param in other.recipe["start"]: new_recipe["start"].append((translate(idx), param)) # Join 'error' nodes diff --git a/src/workflows/recipe/validate.py b/src/workflows/recipe/validate.py index 10d6d117..7bb72063 100644 --- a/src/workflows/recipe/validate.py +++ b/src/workflows/recipe/validate.py @@ -77,5 +77,4 @@ def main(): if __name__ == "__main__": - main() diff --git a/src/workflows/recipe/wrapper.py b/src/workflows/recipe/wrapper.py index 34c39450..c15de95d 100644 --- a/src/workflows/recipe/wrapper.py +++ b/src/workflows/recipe/wrapper.py @@ -2,7 +2,8 @@ import logging import time -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import workflows.recipe diff --git a/src/workflows/services/__init__.py b/src/workflows/services/__init__.py index 9b008ab7..5a6a1d45 100644 --- a/src/workflows/services/__init__.py +++ b/src/workflows/services/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -import pkg_resources +from importlib.metadata import entry_points def lookup(service: str): @@ -25,10 +25,7 @@ def get_known_services(): setattr( get_known_services, "cache", - { - e.name: e.load - for e in pkg_resources.iter_entry_points("workflows.services") - }, + {e.name: e.load for e in entry_points(group="workflows.services")}, ) register = get_known_services.cache.copy() return register diff --git a/src/workflows/services/common_service.py b/src/workflows/services/common_service.py index bf7f93c2..de2ef704 100644 --- a/src/workflows/services/common_service.py +++ b/src/workflows/services/common_service.py @@ -7,7 +7,7 @@ import queue import threading import time -from typing import Any, Dict +from typing import Any import workflows import workflows.logging @@ -128,7 +128,7 @@ def in_shutdown(self): # Any keyword arguments set on service invocation - start_kwargs: Dict[Any, Any] = {} + start_kwargs: dict[Any, Any] = {} # Not so overrideable functions --------------------------------------------- diff --git a/src/workflows/transport/__init__.py b/src/workflows/transport/__init__.py index 0ede6b1b..019a6b91 100644 --- a/src/workflows/transport/__init__.py +++ b/src/workflows/transport/__init__.py @@ -2,9 +2,8 @@ import argparse import optparse -from typing import TYPE_CHECKING, Type - -import pkg_resources +from importlib.metadata import entry_points +from typing import TYPE_CHECKING if TYPE_CHECKING: from .common_transport import CommonTransport @@ -12,7 +11,7 @@ default_transport = "PikaTransport" -def lookup(transport: str) -> Type[CommonTransport]: +def lookup(transport: str) -> type[CommonTransport]: """Get a transport layer class based on its name.""" return get_known_transports().get( transport, get_known_transports()[default_transport] @@ -55,15 +54,12 @@ def add_command_line_options( transport().add_command_line_options(parser) -def get_known_transports() -> dict[str, Type[CommonTransport]]: +def get_known_transports() -> dict[str, type[CommonTransport]]: """Return a dictionary of all known transport mechanisms.""" if not hasattr(get_known_transports, "cache"): setattr( get_known_transports, "cache", - { - e.name: e.load() - for e in pkg_resources.iter_entry_points("workflows.transport") - }, + {e.name: e.load() for e in entry_points(group="workflows.transport")}, ) return get_known_transports.cache.copy() # type: ignore diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index 03537ab7..f97cb9fb 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -2,7 +2,8 @@ import decimal import logging -from typing import Any, Callable, Dict, Mapping, NamedTuple, Optional, Set, Type +from collections.abc import Callable, Mapping +from typing import Any, NamedTuple import workflows from workflows.transport import middleware @@ -20,9 +21,9 @@ class CommonTransport: subscriptions and transactions.""" __callback_interceptor = None - __subscriptions: Dict[int, Dict[str, Any]] = {} + __subscriptions: dict[int, dict[str, Any]] = {} __subscription_id: int = 0 - __transactions: Set[int] = set() + __transactions: set[int] = set() __transaction_id: int = 0 log = logging.getLogger("workflows.transport") @@ -32,14 +33,14 @@ class CommonTransport: # def __init__( - self, middleware: list[Type[middleware.BaseTransportMiddleware]] = None + self, middleware: list[type[middleware.BaseTransportMiddleware]] = None ): if middleware is None: self.middleware = [] else: self.middleware = middleware - def add_middleware(self, middleware: Type[middleware.BaseTransportMiddleware]): + def add_middleware(self, middleware: type[middleware.BaseTransportMiddleware]): self.middleware.insert(0, middleware) @classmethod @@ -99,7 +100,7 @@ def mangled_callback(header, message): @middleware.wrap def subscribe_temporary( - self, channel_hint: Optional[str], callback: MessageCallback, **kwargs + self, channel_hint: str | None, callback: MessageCallback, **kwargs ) -> TemporarySubscription: """Listen to a new queue that is specifically created for this connection, and has a limited lifetime. Notify for messages via callback function. @@ -320,7 +321,7 @@ def broadcast_status(self, status: dict) -> None: raise NotImplementedError @middleware.wrap - def ack(self, message, subscription_id: Optional[int] = None, **kwargs): + def ack(self, message, subscription_id: int | None = None, **kwargs): """Acknowledge receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. :param message: ID of the message to be acknowledged, OR a dictionary @@ -351,7 +352,7 @@ def ack(self, message, subscription_id: Optional[int] = None, **kwargs): self._ack(message_id, subscription_id=subscription_id, **kwargs) @middleware.wrap - def nack(self, message, subscription_id: Optional[int] = None, **kwargs): + def nack(self, message, subscription_id: int | None = None, **kwargs): """Reject receipt of a message. This only makes sense when the 'acknowledgement' flag was set for the relevant subscription. :param message: ID of the message to be rejected, OR a dictionary @@ -380,7 +381,7 @@ def nack(self, message, subscription_id: Optional[int] = None, **kwargs): self._nack(message_id, subscription_id=subscription_id, **kwargs) @middleware.wrap - def transaction_begin(self, subscription_id: Optional[int] = None, **kwargs) -> int: + def transaction_begin(self, subscription_id: int | None = None, **kwargs) -> int: """Start a new transaction. :param **kwargs: Further parameters for the transport layer. :return: A transaction ID that can be passed to other functions. @@ -462,7 +463,7 @@ def _subscribe_broadcast(self, sub_id: int, channel, callback, **kwargs): def _subscribe_temporary( self, sub_id: int, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> str: @@ -530,7 +531,7 @@ def _nack(self, message_id, subscription_id, **kwargs): raise NotImplementedError("Transport interface not implemented") def _transaction_begin( - self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs + self, transaction_id: int, *, subscription_id: int | None = None, **kwargs ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. diff --git a/src/workflows/transport/middleware/__init__.py b/src/workflows/transport/middleware/__init__.py index a8227e11..0b2ad2ef 100644 --- a/src/workflows/transport/middleware/__init__.py +++ b/src/workflows/transport/middleware/__init__.py @@ -4,7 +4,8 @@ import inspect import logging import time -from typing import TYPE_CHECKING, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING if TYPE_CHECKING: from workflows.transport.common_transport import ( @@ -36,7 +37,7 @@ def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int: def subscribe_temporary( self, call_next: Callable, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> TemporarySubscription: @@ -74,7 +75,7 @@ def ack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): call_next(message, subscription_id=subscription_id, **kwargs) @@ -83,13 +84,13 @@ def nack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): call_next(message, subscription_id=subscription_id, **kwargs) def transaction_begin( - self, call_next: Callable, subscription_id: Optional[int] = None, **kwargs + self, call_next: Callable, subscription_id: int | None = None, **kwargs ) -> int: return call_next(subscription_id=subscription_id, **kwargs) @@ -136,7 +137,7 @@ def ack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): call_next(message, subscription_id=subscription_id, **kwargs) @@ -147,7 +148,7 @@ def nack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): call_next(message, subscription_id=subscription_id, **kwargs) @@ -195,7 +196,7 @@ def wrapped_callback(header, message): def subscribe_temporary( self, call_next: Callable, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> TemporarySubscription: @@ -234,7 +235,6 @@ def wrapped_callback(header, message): def wrap(f: Callable): @functools.wraps(f) def wrapper(self, *args, **kwargs): - return functools.reduce( lambda call_next, m: lambda *args, **kwargs: getattr(m, f.__name__)( call_next, *args, **kwargs diff --git a/src/workflows/transport/middleware/prometheus.py b/src/workflows/transport/middleware/prometheus.py index 9fa2d81c..6dd57f42 100644 --- a/src/workflows/transport/middleware/prometheus.py +++ b/src/workflows/transport/middleware/prometheus.py @@ -2,7 +2,7 @@ import functools import time -from typing import Callable, Optional +from collections.abc import Callable from prometheus_client import Counter, Gauge, Histogram @@ -100,7 +100,7 @@ def wrapped_callback(header, message): def subscribe_temporary( self, call_next: Callable, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> TemporarySubscription: @@ -155,7 +155,7 @@ def ack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): ACKS.labels(source=self.source).inc() @@ -165,7 +165,7 @@ def nack( self, call_next: Callable, message, - subscription_id: Optional[int] = None, + subscription_id: int | None = None, **kwargs, ): NACKS.labels(source=self.source).inc() diff --git a/src/workflows/transport/offline_transport.py b/src/workflows/transport/offline_transport.py index 71c56b85..b4af8711 100644 --- a/src/workflows/transport/offline_transport.py +++ b/src/workflows/transport/offline_transport.py @@ -6,7 +6,7 @@ import logging import pprint import uuid -from typing import Any, Dict, Optional, Type +from typing import Any import workflows.util from workflows.transport import middleware @@ -23,12 +23,12 @@ class OfflineTransport(CommonTransport): """Abstraction layer for messaging infrastructure. Here we.. do nothing.""" # Add for compatibility - defaults: Dict[Any, Any] = {} + defaults: dict[Any, Any] = {} # Effective configuration - config: Dict[Any, Any] = {} + config: dict[Any, Any] = {} def __init__( - self, middleware: list[Type[middleware.BaseTransportMiddleware]] = None + self, middleware: list[type[middleware.BaseTransportMiddleware]] = None ): self._connected = False super().__init__(middleware=middleware) @@ -60,7 +60,7 @@ def _subscribe(self, sub_id, channel, callback, **kwargs): def _subscribe_temporary( self, sub_id: int, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> str: diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index f1b5a055..2a8f5bd1 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -10,9 +10,10 @@ import threading import time import uuid +from collections.abc import Callable, Iterable from concurrent.futures import Future from enum import Enum, auto -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any import pika.exceptions from bidict import bidict @@ -59,10 +60,10 @@ class PikaTransport(CommonTransport): } # Effective configuration - config: Dict[Any, Any] = {} + config: dict[Any, Any] = {} def __init__( - self, middleware: list[Type[middleware.BaseTransportMiddleware]] = None + self, middleware: list[type[middleware.BaseTransportMiddleware]] = None ): self._channel = None self._conn = None @@ -240,7 +241,7 @@ def set_parameter(option, opt, value, parser): callback=set_parameter, ) - def _generate_connection_parameters(self) -> List[pika.ConnectionParameters]: + def _generate_connection_parameters(self) -> list[pika.ConnectionParameters]: username = self.config.get("--rabbit-user", self.defaults.get("--rabbit-user")) password = self.config.get("--rabbit-pass", self.defaults.get("--rabbit-pass")) credentials = pika.PlainCredentials(username, password) @@ -440,7 +441,7 @@ def _subscribe_broadcast( def _subscribe_temporary( self, sub_id: int, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, *, acknowledgement: bool = False, @@ -494,8 +495,8 @@ def _send( headers=None, delay=None, expiration=None, - transaction: Optional[int] = None, - exchange: Optional[str] = None, + transaction: int | None = None, + exchange: str | None = None, **kwargs, ): """ @@ -540,8 +541,8 @@ def _broadcast( message, headers=None, delay=None, - expiration: Optional[int] = None, - transaction: Optional[int] = None, + expiration: int | None = None, + transaction: int | None = None, **kwargs, ): """Send a message to a fanout exchange. @@ -579,7 +580,7 @@ def _broadcast( ).result() def _transaction_begin( - self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs + self, transaction_id: int, *, subscription_id: int | None = None, **kwargs ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. @@ -724,13 +725,13 @@ class _PikaSubscription: reconnectable: Are we allowed to reconnect to this subscription """ - arguments: Dict[str, Any] + arguments: dict[str, Any] auto_ack: bool destination: str kind: _PikaSubscriptionKind on_message_callback: PikaCallback = dataclasses.field(repr=False) prefetch_count: int - queue: Optional[str] = dataclasses.field(init=False, default=None) + queue: str | None = dataclasses.field(init=False, default=None) reconnectable: bool @@ -763,25 +764,25 @@ def __init__( # Internal store of subscriptions, to resubscribe if necessary. Keys are # unique and auto-generated, and known as subscription IDs or consumer tags # (strictly: pika/AMQP consumer tags are strings, not integers) - self._subscriptions: Dict[int, _PikaSubscription] = {} + self._subscriptions: dict[int, _PikaSubscription] = {} # The pika connection object - self._connection: Optional[pika.BlockingConnection] = None + self._connection: pika.BlockingConnection | None = None # Index of per-subscription channels. self._pika_channels: bidict[int, BlockingChannel] = bidict() # Bidirectional index of all ongoing transactions. May include the shared channel self._transaction_on_channel: bidict[BlockingChannel, int] = bidict() # Information on whether a channel has uncommitted messages - self._channel_has_active_tx: Dict[BlockingChannel, bool] = {} + self._channel_has_active_tx: dict[BlockingChannel, bool] = {} # Information on whether a channel is transactional - self._channel_is_transactional: Dict[BlockingChannel, bool] = {} + self._channel_is_transactional: dict[BlockingChannel, bool] = {} # A common, shared channel, used for sending messages outside of transactions. - self._pika_shared_channel: Optional[BlockingChannel] + self._pika_shared_channel: BlockingChannel | None # Are we allowed to reconnect. Can only be turned off, never on self._reconnection_allowed: bool = True # Our list of connection parameters, so we know where to connect to self._connection_parameters = list(connection_parameters) # If we failed with an unexpected exception - self._exc_info: Optional[Tuple[Any, Any, Any]] = None + self._exc_info: tuple[Any, Any, Any] | None = None self._reconnection_attempt_limit = reconnection_attempts # General bookkeeping events @@ -829,9 +830,7 @@ def stop(self): except pika.exceptions.ConnectionWrongStateError: pass - def join( - self, timeout: Optional[float] = None, *, re_raise: bool = False, stop=False - ): + def join(self, timeout: float | None = None, *, re_raise: bool = False, stop=False): """Wait until the thread terminates. Args: @@ -1067,10 +1066,10 @@ def send( self, exchange: str, routing_key: str, - body: Union[str, bytes], + body: str | bytes, properties: pika.spec.BasicProperties = None, mandatory: bool = True, - transaction_id: Optional[int] = None, + transaction_id: int | None = None, ) -> Future[None]: """Send a message. Thread-safe.""" @@ -1108,7 +1107,7 @@ def ack( subscription_id: int, *, multiple=False, - transaction_id: Optional[int], + transaction_id: int | None, ): if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to ACK") @@ -1147,7 +1146,7 @@ def nack( *, multiple=False, requeue=True, - transaction_id: Optional[int], + transaction_id: int | None, ): if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to NACK") @@ -1182,7 +1181,7 @@ def _nack_callback(): ) def tx_select( - self, transaction_id: int, subscription_id: Optional[int] + self, transaction_id: int, subscription_id: int | None ) -> Future[None]: """Set a channel to transaction mode. Thread-safe. :param transaction_id: ID for this transaction in the transport layer. diff --git a/src/workflows/transport/stomp_transport.py b/src/workflows/transport/stomp_transport.py index 89ff3b83..f55777eb 100644 --- a/src/workflows/transport/stomp_transport.py +++ b/src/workflows/transport/stomp_transport.py @@ -5,7 +5,7 @@ import threading import time import uuid -from typing import Any, Dict, Optional, Type +from typing import Any import stomp @@ -31,10 +31,10 @@ class StompTransport(CommonTransport): "--stomp-prfx": "", } # Effective configuration - config: Dict[Any, Any] = {} + config: dict[Any, Any] = {} def __init__( - self, middleware: list[Type[middleware.BaseTransportMiddleware]] = None + self, middleware: list[type[middleware.BaseTransportMiddleware]] = None ): self._connected = False self._namespace = "" @@ -348,7 +348,7 @@ def _subscribe_broadcast(self, sub_id, channel, callback, **kwargs): def _subscribe_temporary( self, sub_id: int, - channel_hint: Optional[str], + channel_hint: str | None, callback: MessageCallback, **kwargs, ) -> str: diff --git a/tests/contrib/test_start_service.py b/tests/contrib/test_start_service.py index 7b5471a3..b65c9cf3 100644 --- a/tests/contrib/test_start_service.py +++ b/tests/contrib/test_start_service.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows.contrib.start_service @@ -55,7 +54,6 @@ def test_script_initialises_transport_and_starts_frontend( @mock.patch("workflows.contrib.start_service.workflows.frontend") @mock.patch("workflows.contrib.start_service.workflows.services") def test_add_metrics_option(mock_services, mock_frontend, mock_tlookup, mock_parser): - mock_options = mock.Mock() mock_options.service = "someservice" mock_options.transport = mock.sentinel.transport diff --git a/tests/frontend/test_frontend.py b/tests/frontend/test_frontend.py index 15cf6d29..7846800c 100644 --- a/tests/frontend/test_frontend.py +++ b/tests/frontend/test_frontend.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows.frontend from workflows.services.common_service import CommonService diff --git a/tests/recipe/test_recipe.py b/tests/recipe/test_recipe.py index c27eb8f2..0320ce9f 100644 --- a/tests/recipe/test_recipe.py +++ b/tests/recipe/test_recipe.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows import workflows.recipe @@ -363,8 +362,6 @@ def test_merging_recipes(): # There is a 'C service' assert any( - map( - lambda x: (isinstance(x, dict) and x.get("service") == "C service"), - C.recipe.values(), - ) + isinstance(x, dict) and x.get("service") == "C service" + for x in C.recipe.values() ) diff --git a/tests/recipe/test_validate.py b/tests/recipe/test_validate.py index 46bed6fc..1eeedc5f 100644 --- a/tests/recipe/test_validate.py +++ b/tests/recipe/test_validate.py @@ -9,7 +9,6 @@ from unittest import mock import pytest - import workflows from workflows.recipe.validate import main, validate_recipe diff --git a/tests/recipe/test_wrapped_recipe.py b/tests/recipe/test_wrapped_recipe.py index bb5dfa7e..2f7fba81 100644 --- a/tests/recipe/test_wrapped_recipe.py +++ b/tests/recipe/test_wrapped_recipe.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows.transport.common_transport from workflows.recipe import Recipe from workflows.recipe.wrapper import RecipeWrapper diff --git a/tests/services/test.py b/tests/services/test.py index 388b4d47..e5821245 100644 --- a/tests/services/test.py +++ b/tests/services/test.py @@ -1,8 +1,21 @@ from __future__ import annotations import workflows.services +from workflows.services.common_service import CommonService def test_known_services_is_a_dictionary(): """Check services register build in CommonService.""" assert isinstance(workflows.services.get_known_services(), dict) + + +def test_enumerate_services(): + """Verify we can discover the installed services.""" + services = workflows.services.get_known_services() + assert services.keys() == { + "SampleConsumer", + "SampleProducer", + "SampleTxn", + "SampleTxnProducer", + } + assert all(issubclass(service(), CommonService) for service in services.values()) diff --git a/tests/services/test_common_service.py b/tests/services/test_common_service.py index a738a3e4..d6878172 100644 --- a/tests/services/test_common_service.py +++ b/tests/services/test_common_service.py @@ -5,7 +5,6 @@ from unittest import mock import pytest - from workflows.services.common_service import Commands, CommonService, Priority @@ -121,7 +120,8 @@ def test_log_message_fieldvalue_pairs_are_removed_outside_their_context(): def test_log_message_fieldvalue_pairs_are_attached_to_unhandled_exceptions_and_logged_properly(): """When an exception falls through the extend_log context handler the fields are removed from future log messages, - but they are also attached to the exception object, as they may contain valuable information for debugging.""" + but they are also attached to the exception object, as they may contain valuable information for debugging. + """ fe_pipe = mock.Mock() service = CommonService() service.connect(frontend=fe_pipe) diff --git a/tests/services/test_sample_producer.py b/tests/services/test_sample_producer.py index 6e43c98b..102ddf51 100644 --- a/tests/services/test_sample_producer.py +++ b/tests/services/test_sample_producer.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows.services import workflows.services.sample_producer from workflows.transport.offline_transport import OfflineTransport diff --git a/tests/transport/test.py b/tests/transport/test.py index 9c55a1f8..91c4fa59 100644 --- a/tests/transport/test.py +++ b/tests/transport/test.py @@ -1,11 +1,23 @@ from __future__ import annotations import workflows.transport +from workflows.transport.common_transport import CommonTransport def test_known_transports_is_a_dictionary(): """Check transport register build in CommonTransport.""" - assert isinstance(workflows.transport.get_known_transports(), dict) + transports = workflows.transport.get_known_transports() + print(transports) + assert isinstance(transports, dict) + + +def test_enumerate_transports(): + """Verify we can discover the installed transports.""" + transports = workflows.transport.get_known_transports() + assert transports.keys() == {"OfflineTransport", "PikaTransport", "StompTransport"} + assert all( + issubclass(transport, CommonTransport) for transport in transports.values() + ) def test_load_any_transport(): diff --git a/tests/transport/test_common.py b/tests/transport/test_common.py index c30710cd..6275ca57 100644 --- a/tests/transport/test_common.py +++ b/tests/transport/test_common.py @@ -3,7 +3,6 @@ from unittest import mock import pytest - import workflows from workflows.transport.common_transport import CommonTransport @@ -167,7 +166,9 @@ def test_callbacks_can_be_intercepted(mangling): ) # Pass through (tests the value passed to the interceptor function is sensible) - intercept = lambda x: x + def intercept(x): + return x + ct.subscription_callback_set_intercept(intercept) ct.subscription_callback(subid)(mock.sentinel.header, mock.sentinel.message) diff --git a/tests/transport/test_middleware.py b/tests/transport/test_middleware.py index 0b7bd7f6..6c265f35 100644 --- a/tests/transport/test_middleware.py +++ b/tests/transport/test_middleware.py @@ -6,7 +6,6 @@ from unittest import mock import pytest - from workflows.transport import middleware from workflows.transport.offline_transport import OfflineTransport diff --git a/tests/transport/test_offline.py b/tests/transport/test_offline.py index d3fd2d1b..1d94839b 100644 --- a/tests/transport/test_offline.py +++ b/tests/transport/test_offline.py @@ -5,7 +5,6 @@ from unittest import mock import pytest - import workflows.transport from workflows.transport.offline_transport import OfflineTransport diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index e83339d8..9f075ec3 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -13,7 +13,6 @@ import pika import pytest - import workflows.transport.pika_transport from workflows.transport.common_transport import TemporarySubscription from workflows.transport.pika_transport import PikaTransport, _PikaThread @@ -353,7 +352,8 @@ def test_sending_message_with_expiration(mockpika, mock_pikathread): @mock.patch("workflows.transport.pika_transport.pika") def test_error_handling_on_send(mockpika, mock_pikathread): """Unrecoverable errors during sending should lead to one reconnection attempt. - Further errors should raise an Exception, further send attempts to try to reconnect.""" + Further errors should raise an Exception, further send attempts to try to reconnect. + """ pytest.xfail("Don't understand send failure modes yet") @@ -468,7 +468,8 @@ def test_broadcasting_message_with_expiration(mockpika, mock_pikathread): @mock.patch("workflows.transport.pika_transport.pika") def test_error_handling_on_broadcast(mockpika): """Unrecoverable errors during broadcasting should lead to one reconnection attempt. - Further errors should raise an Exception, further send attempts to try to reconnect.""" + Further errors should raise an Exception, further send attempts to try to reconnect. + """ pytest.xfail("Don't understand send lifecycle errors yet") transport = PikaTransport() transport.connect() @@ -913,7 +914,7 @@ def register_cleanup(self, task): yield channel finally: # Make an attempt to run all of the shutdown tasks - for (filename, lineno, task) in reversed(channel._on_close): + for filename, lineno, task in reversed(channel._on_close): try: print(f"Cleaning up from {filename}:{lineno}") task() @@ -1156,7 +1157,6 @@ def _get_message(*args): def test_pikathread_ack_transaction(test_channel, connection_params): - queue = test_channel.temporary_queue_declare() thread = _PikaThread(connection_params) try: @@ -1203,7 +1203,6 @@ def _get_message(channel, method_frame, header_frame, body): def test_pikathread_nack_transaction(test_channel, connection_params): - queue = test_channel.temporary_queue_declare() thread = _PikaThread(connection_params) try: @@ -1252,7 +1251,6 @@ def _get_message(channel, method_frame, header_frame, body): def test_pikathread_tx_rollback_nack(test_channel, connection_params): - queue = test_channel.temporary_queue_declare() thread = _PikaThread(connection_params) try: diff --git a/tests/transport/test_stomp.py b/tests/transport/test_stomp.py index 9deaf73b..2123b86a 100644 --- a/tests/transport/test_stomp.py +++ b/tests/transport/test_stomp.py @@ -13,7 +13,6 @@ import pytest import stomp as stomppy - import workflows import workflows.transport from workflows.transport.common_transport import TemporarySubscription