diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml deleted file mode 100644 index ab34859c5..000000000 --- a/.github/workflows/integration-test.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: Run ocrd network integration tests - -on: - push: - branches: [ "master" ] - pull_request: - branches: [ "master" ] - -jobs: - build: - - runs-on: ${{ matrix.os }} - - strategy: - fail-fast: false - matrix: - python-version: - - '3.7' - - '3.8' - - '3.9' - - '3.10' - - '3.11' - os: - - ubuntu-22.04 - # - macos-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up Homebrew - id: set-up-homebrew - uses: Homebrew/actions/setup-homebrew@master - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - if [[ "${{ matrix.os }}" == "ubuntu"* ]];then - sudo apt-get -y update - sudo make deps-ubuntu - else - HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 \ - HOMEBREW_NO_AUTO_UPDATE=1 \ - brew install imagemagick geos bash # opencv - fi - make install deps-test - - name: Install Docker on macOS - if: runner.os == 'macos' - run: | - brew install docker docker-compose - colima start - - name: Test network integration with pytest - run: | - if [[ "${{ matrix.os }}" == "macos"* ]];then - make integration-test DOCKER_COMPOSE=docker-compose - else - make integration-test - fi diff --git a/.github/workflows/network-testing.yml b/.github/workflows/network-testing.yml index 23913f2db..484ffa23e 100644 --- a/.github/workflows/network-testing.yml +++ b/.github/workflows/network-testing.yml @@ -5,6 +5,7 @@ on: branches: [ "master" ] pull_request: branches: [ "master" ] + workflow_dispatch: # run manually jobs: build: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index cfe282cd5..d83948c77 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -26,7 +26,7 @@ jobs: os: - ubuntu-22.04 - ubuntu-20.04 - - macos-latest + # - macos-latest steps: - uses: actions/checkout@v3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 46b67a414..3cfee9252 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2038,7 +2038,10 @@ Fixed Initial Release +<<<<<<< HEAD +======= [2.63.3]: ../../compare/v2.63.3..v2.63.1 +>>>>>>> master [2.63.2]: ../../compare/v2.63.2..v2.63.1 [2.63.1]: ../../compare/v2.63.1..v2.63.0 [2.63.0]: ../../compare/v2.63.0..v2.62.0 diff --git a/Makefile b/Makefile index 1f0777abe..5b83d1a17 100644 --- a/Makefile +++ b/Makefile @@ -39,6 +39,9 @@ help: @echo " docker Build docker image" @echo " docker-cuda Build docker image for GPU / CUDA" @echo " pypi Build wheels and source dist and twine upload them" + @echo " ocrd network tests" + @echo " network-module-test Run all ocrd_network module tests" + @echo " network-integration-test Run all ocrd_network integration tests (docker and docker compose required)" @echo "" @echo " Variables" @echo "" @@ -217,12 +220,32 @@ test: assets --ignore-glob="$(TESTDIR)/**/*bench*.py" \ --ignore-glob="$(TESTDIR)/network/*.py" \ $(TESTDIR) - cd ocrd_utils ; $(PYTHON) -m pytest --continue-on-collection-errors -k TestLogging -k TestDecorators $(TESTDIR) + $(MAKE) test-logging + +test-logging: assets + # copy default logging to temporary directory and run logging tests from there + tempdir=$$(mktemp -d); \ + cp src/ocrd_utils/ocrd_logging.conf $$tempdir; \ + cd $$tempdir; \ + $(PYTHON) -m pytest --continue-on-collection-errors -k TestLogging -k TestDecorators $(TESTDIR); \ + rm -r 
$$tempdir/ocrd_logging.conf $$tempdir/.benchmarks; \ + rmdir $$tempdir + +network-module-test: assets + $(PYTHON) \ + -m pytest $(PYTEST_ARGS) -k 'test_modules_' -v --durations=10\ + --ignore-glob="$(TESTDIR)/network/test_integration_*.py" \ + $(TESTDIR)/network INTEGRATION_TEST_IN_DOCKER = docker exec core_test -integration-test: +network-integration-test: + $(DOCKER_COMPOSE) --file tests/network/docker-compose.yml up -d + -$(INTEGRATION_TEST_IN_DOCKER) pytest -k 'test_integration_' -v + $(DOCKER_COMPOSE) --file tests/network/docker-compose.yml down --remove-orphans + +network-integration-test-cicd: $(DOCKER_COMPOSE) --file tests/network/docker-compose.yml up -d - -$(INTEGRATION_TEST_IN_DOCKER) pytest -k 'test_rmq or test_db or test_processing_server' -v + $(INTEGRATION_TEST_IN_DOCKER) pytest -k 'test_integration_' -v $(DOCKER_COMPOSE) --file tests/network/docker-compose.yml down --remove-orphans benchmark: diff --git a/README.md b/README.md index 7999ec181..c37d28400 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ A minimal [OCR-D processor](https://ocr-d.de/en/user_guide#using-the-ocr-d-proce Almost all behaviour of the OCR-D/core software is configured via CLI options and flags, which can be listed with the `--help` flag that all CLI support. -Some parts of the software are configured via environement variables: +Some parts of the software are configured via environment variables: * `OCRD_METS_CACHING`: If set to `true`, access to the METS file is cached, speeding in-memory search and modification. * `OCRD_PROFILE`: This variable configures the built-in CPU and memory profiling. If empty, no profiling is done. Otherwise expected to contain any of the following tokens: @@ -105,7 +105,7 @@ Some parts of the software are configured via environement variables: * `OCRD_NETWORK_SERVER_ADDR_PROCESSING`: Default address of Processing Server to connect to (for `ocrd network client processing`). * `OCRD_NETWORK_SERVER_ADDR_WORKFLOW`: Default address of Workflow Server to connect to (for `ocrd network client workflow`). * `OCRD_NETWORK_SERVER_ADDR_WORKSPACE`: Default address of Workspace Server to connect to (for `ocrd network client workspace`). -* `OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS`: Number of attempts for a worker to create its queue. Helpfull if the rabbitmq-server needs time to be fully started. +* `OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS`: Number of attempts for a worker to create its queue. Helpful if the rabbitmq-server needs time to be fully started. 
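A minimal sketch of how the environment variables documented above can be read through `ocrd_utils.config`, the same accessor the rewritten network client further down in this diff uses; the printed values are simply whatever the environment provides, and no server needs to be running:

from ocrd_utils import config

# Default addresses used by `ocrd network client ...` when no CLI option is given
print(config.OCRD_NETWORK_SERVER_ADDR_PROCESSING)
print(config.OCRD_NETWORK_SERVER_ADDR_WORKFLOW)
print(config.OCRD_NETWORK_SERVER_ADDR_WORKSPACE)

# Renamed variable: how often a RabbitMQ client retries connecting before giving up
print(config.OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS)

# The same description string that the `ocrd --help` text in this diff embeds
print(config.describe("OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS"))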
## Packages diff --git a/src/ocrd/cli/__init__.py b/src/ocrd/cli/__init__.py index 9bfa21276..89b5d7554 100644 --- a/src/ocrd/cli/__init__.py +++ b/src/ocrd/cli/__init__.py @@ -43,7 +43,7 @@ \b {config.describe('OCRD_NETWORK_SERVER_ADDR_WORKSPACE')} \b -{config.describe('OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS')} +{config.describe('OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS')} \b {config.describe('OCRD_PROFILE_FILE')} \b diff --git a/src/ocrd/cli/network.py b/src/ocrd/cli/network.py index 9203b8840..72ecefae4 100644 --- a/src/ocrd/cli/network.py +++ b/src/ocrd/cli/network.py @@ -7,7 +7,6 @@ """ import click -import logging from ocrd_utils import initLogging from ocrd_network.cli import ( client_cli, diff --git a/src/ocrd/decorators/__init__.py b/src/ocrd/decorators/__init__.py index ecfef5dbb..4e46dec0b 100644 --- a/src/ocrd/decorators/__init__.py +++ b/src/ocrd/decorators/__init__.py @@ -10,7 +10,7 @@ set_json_key_value_overrides, ) from ocrd_validators import WorkspaceValidator -from ocrd_network import ProcessingWorker, ProcessorServer, NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER +from ocrd_network import ProcessingWorker, ProcessorServer, AgentType from ..resolver import Resolver from ..processor.base import run_processor @@ -20,7 +20,7 @@ from .ocrd_cli_options import ocrd_cli_options from .mets_find_options import mets_find_options -SUBCOMMANDS = [NETWORK_AGENT_WORKER, NETWORK_AGENT_SERVER] +SUBCOMMANDS = [AgentType.PROCESSING_WORKER, AgentType.PROCESSOR_SERVER] def ocrd_cli_wrap_processor( @@ -142,19 +142,19 @@ def check_and_run_network_agent(ProcessorClass, subcommand: str, address: str, d if not database: raise ValueError(f"Option '--database' is invalid for subcommand {subcommand}") - if subcommand == NETWORK_AGENT_SERVER: + if subcommand == AgentType.PROCESSOR_SERVER: if not address: raise ValueError(f"Option '--address' required for subcommand {subcommand}") if queue: raise ValueError(f"Option '--queue' invalid for subcommand {subcommand}") - if subcommand == NETWORK_AGENT_WORKER: + if subcommand == AgentType.PROCESSING_WORKER: if address: raise ValueError(f"Option '--address' invalid for subcommand {subcommand}") if not queue: raise ValueError(f"Option '--queue' required for subcommand {subcommand}") processor = ProcessorClass(workspace=None) - if subcommand == NETWORK_AGENT_WORKER: + if subcommand == AgentType.PROCESSING_WORKER: processing_worker = ProcessingWorker( rabbitmq_addr=queue, mongodb_addr=database, @@ -166,7 +166,7 @@ def check_and_run_network_agent(ProcessorClass, subcommand: str, address: str, d processing_worker.connect_consumer() # Start consuming from the queue with name `processor_name` processing_worker.start_consuming() - elif subcommand == NETWORK_AGENT_SERVER: + elif subcommand == AgentType.PROCESSOR_SERVER: # TODO: Better validate that inside the ProcessorServer itself host, port = address.split(':') processor_server = ProcessorServer( @@ -175,4 +175,6 @@ def check_and_run_network_agent(ProcessorClass, subcommand: str, address: str, d processor_class=ProcessorClass, ) processor_server.run_server(host=host, port=int(port)) + else: + raise ValueError(f"Unknown network agent type, must be one of: {SUBCOMMANDS}") sys.exit(0) diff --git a/src/ocrd/decorators/ocrd_cli_options.py b/src/ocrd/decorators/ocrd_cli_options.py index d1d3a9624..f32955838 100644 --- a/src/ocrd/decorators/ocrd_cli_options.py +++ b/src/ocrd/decorators/ocrd_cli_options.py @@ -1,7 +1,7 @@ import click from click import option, Path, group, command, argument from ocrd_utils import 
DEFAULT_METS_BASENAME -from ocrd_network import NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER +from ocrd_network import AgentType from .parameter_option import parameter_option, parameter_override_option from .loglevel_option import loglevel_option from ocrd_network import ( @@ -57,7 +57,7 @@ def cli(mets_url): # subcommands. So we have to work around that by creating a # pseudo-subcommand handled in ocrd_cli_wrap_processor argument('subcommand', nargs=1, required=False, - type=click.Choice([NETWORK_AGENT_WORKER, NETWORK_AGENT_SERVER])), + type=click.Choice([AgentType.PROCESSING_WORKER, AgentType.PROCESSOR_SERVER])), ] for param in params: param(f) diff --git a/src/ocrd_network/__init__.py b/src/ocrd_network/__init__.py index 08153e594..189a48100 100644 --- a/src/ocrd_network/__init__.py +++ b/src/ocrd_network/__init__.py @@ -1,10 +1,7 @@ from .client import Client -from .constants import NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER +from .constants import AgentType, JobState from .processing_server import ProcessingServer from .processing_worker import ProcessingWorker from .processor_server import ProcessorServer -from .param_validators import ( - DatabaseParamType, - ServerAddressParamType, - QueueServerParamType -) +from .param_validators import DatabaseParamType, ServerAddressParamType, QueueServerParamType +from .server_cache import CacheLockedPages, CacheProcessingRequests diff --git a/src/ocrd_network/cli/client.py b/src/ocrd_network/cli/client.py index 5c62ac44e..8086658e0 100644 --- a/src/ocrd_network/cli/client.py +++ b/src/ocrd_network/cli/client.py @@ -1,13 +1,9 @@ import click from typing import Optional -from ocrd_utils import DEFAULT_METS_BASENAME - -from ocrd.decorators import ( - parameter_option, - parameter_override_option -) +from ocrd.decorators import parameter_option from ocrd_network import Client +from ocrd_utils import DEFAULT_METS_BASENAME @click.group('client') diff --git a/src/ocrd_network/cli/processing_server.py b/src/ocrd_network/cli/processing_server.py index cf2aacab4..50a42887c 100644 --- a/src/ocrd_network/cli/processing_server.py +++ b/src/ocrd_network/cli/processing_server.py @@ -1,8 +1,5 @@ import click -from .. import ( - ProcessingServer, - ServerAddressParamType -) +from ocrd_network import ProcessingServer, ServerAddressParamType @click.command('processing-server') diff --git a/src/ocrd_network/cli/processing_worker.py b/src/ocrd_network/cli/processing_worker.py index b626e8b55..3af88e56b 100644 --- a/src/ocrd_network/cli/processing_worker.py +++ b/src/ocrd_network/cli/processing_worker.py @@ -1,11 +1,6 @@ import click from ocrd_utils import get_ocrd_tool_json - -from .. import ( - DatabaseParamType, - ProcessingWorker, - QueueServerParamType -) +from ocrd_network import DatabaseParamType, ProcessingWorker, QueueServerParamType @click.command('processing-worker') diff --git a/src/ocrd_network/cli/processor_server.py b/src/ocrd_network/cli/processor_server.py index 534a9a0fe..50529adda 100644 --- a/src/ocrd_network/cli/processor_server.py +++ b/src/ocrd_network/cli/processor_server.py @@ -1,9 +1,5 @@ import click -from .. 
import ( - DatabaseParamType, - ProcessorServer, - ServerAddressParamType -) +from ocrd_network import DatabaseParamType, ProcessorServer, ServerAddressParamType @click.command('processor-server') @@ -23,12 +19,12 @@ def processor_server_cli(processor_name: str, address: str, database: str): (standalone REST API OCR-D processor) """ try: - # TODO: Better validate that inside the ProcessorServer itself + # Note, the address is already validated with the type field host, port = address.split(':') processor_server = ProcessorServer( mongodb_addr=database, processor_name=processor_name, - processor_class=None, # For readability purposes assigned here + processor_class=None # For readability purposes assigned here ) processor_server.run_server(host=host, port=int(port)) except Exception as e: diff --git a/src/ocrd_network/client.py b/src/ocrd_network/client.py index df04c3c27..9fa0b3994 100644 --- a/src/ocrd_network/client.py +++ b/src/ocrd_network/client.py @@ -1,38 +1,37 @@ -import json -import requests +from json import dumps, loads +from requests import post as requests_post +from ocrd_utils import config, getLogger, LOG_FORMAT -from ocrd_utils import config +from .constants import NETWORK_PROTOCOLS # TODO: This is just a conceptual implementation and first try to # trigger further discussions on how this should look like. class Client: def __init__( - self, - server_addr_processing: str = config.OCRD_NETWORK_SERVER_ADDR_PROCESSING, - server_addr_workflow: str = config.OCRD_NETWORK_SERVER_ADDR_WORKFLOW, - server_addr_workspace: str = config.OCRD_NETWORK_SERVER_ADDR_WORKSPACE, + self, + server_addr_processing: str = config.OCRD_NETWORK_SERVER_ADDR_PROCESSING, + server_addr_workflow: str = config.OCRD_NETWORK_SERVER_ADDR_WORKFLOW, + server_addr_workspace: str = config.OCRD_NETWORK_SERVER_ADDR_WORKSPACE ): + self.log = getLogger(f"ocrd_network.client") self.server_addr_processing = server_addr_processing self.server_addr_workflow = server_addr_workflow self.server_addr_workspace = server_addr_workspace def send_processing_request(self, processor_name: str, req_params: dict): verify_server_protocol(self.server_addr_processing) - req_url = f'{self.server_addr_processing}/processor/{processor_name}' + req_url = f"{self.server_addr_processing}/processor/{processor_name}" req_headers = {"Content-Type": "application/json; charset=utf-8"} - req_json = json.loads(json.dumps(req_params)) - - print(f'Sending processing request to: {req_url}') - response = requests.post(url=req_url, headers=req_headers, json=req_json) + req_json = loads(dumps(req_params)) + self.log.info(f"Sending processing request to: {req_url}") + self.log.debug(req_json) + response = requests_post(url=req_url, headers=req_headers, json=req_json) return response.json() def verify_server_protocol(address: str): - protocol_matched = False - for protocol in ['http://', 'https://']: + for protocol in NETWORK_PROTOCOLS: if address.startswith(protocol): - protocol_matched = True - break - if not protocol_matched: - raise ValueError(f'Wrong/Missing protocol in the server address: {address}') + return + raise ValueError(f"Wrong/Missing protocol in the server address: {address}, must be one of: {NETWORK_PROTOCOLS}") diff --git a/src/ocrd_network/constants.py b/src/ocrd_network/constants.py index cbaccd4cf..53dbd9b11 100644 --- a/src/ocrd_network/constants.py +++ b/src/ocrd_network/constants.py @@ -1,2 +1,60 @@ -NETWORK_AGENT_SERVER = 'server' -NETWORK_AGENT_WORKER = 'worker' +from enum import Enum + +DOCKER_IMAGE_MONGO_DB = "mongo" 
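A usage sketch for the rewritten `ocrd_network.Client` above. The server address, processor name and request payload are illustrative only; the payload fields follow the `PYJobInput` model further down in this diff, and a running Processing Server is assumed:

# Sketch only: sending a processing request with the rewritten client.
from ocrd_network import Client

client = Client(server_addr_processing="http://localhost:8080")  # hypothetical address
response_json = client.send_processing_request(
    processor_name="ocrd-dummy",               # hypothetical processor
    req_params={
        "path_to_mets": "/path/to/mets.xml",   # placeholder path
        "input_file_grps": ["DEFAULT"],
        "output_file_grps": ["OCR-D-DUMMY"],
        "agent_type": "worker",                # i.e. AgentType.PROCESSING_WORKER
        "parameters": {},
    },
)
print(response_json)  # job metadata returned by the Processing Server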
+DOCKER_IMAGE_RABBIT_MQ = "rabbitmq:3.12-management" +# These feature flags are required by default to use the newer version +DOCKER_RABBIT_MQ_FEATURES = "quorum_queue,implicit_default_bindings,classic_mirrored_queue_version" + +NETWORK_PROTOCOLS = ["http://", "https://"] +OCRD_ALL_JSON_TOOLS_URL = "https://ocr-d.de/js/ocrd-all-tool.json" +# Used as a placeholder to lock all pages when no page_id is specified +SERVER_ALL_PAGES_PLACEHOLDER = "all_pages" + + +class AgentType(str, Enum): + PROCESSING_WORKER = "worker" + PROCESSOR_SERVER = "server" + + +class DeployType(str, Enum): + # Deployed by the Processing Server config file + DOCKER = "docker" + NATIVE = "native" + # Deployed through a registration endpoint of the Processing Server + # TODO: That endpoint is still not implemented + EXTERNAL = "external" + + +# TODO: Make the states uppercase +class JobState(str, Enum): + # The processing job is cached inside the Processing Server requests cache + cached = "CACHED" + # The processing job was cancelled due to failed dependencies + cancelled = "CANCELLED" + # Processing job failed + failed = "FAILED" + # The processing job is queued inside the RabbitMQ + queued = "QUEUED" + # Processing job is currently running in a Worker or Processor Server + running = "RUNNING" + # Processing job finished successfully + success = "SUCCESS" + # Processing job has not been assigned yet + unset = "UNSET" + + +class NetworkLoggingDirs(str, Enum): + METS_SERVERS = "mets_servers" + PROCESSING_JOBS = "processing_jobs" + PROCESSING_SERVERS = "processing_servers" + PROCESSING_WORKERS = "processing_workers" + PROCESSOR_SERVERS = "processor_servers" + + +class ServerApiTags(str, Enum): + ADMIN = "admin" + DISCOVERY = "discovery" + PROCESSING = "processing" + TOOLS = "tools" + WORKFLOW = "workflow" + WORKSPACE = "workspace" diff --git a/src/ocrd_network/database.py b/src/ocrd_network/database.py index 946ed95a5..8b0b48925 100644 --- a/src/ocrd_network/database.py +++ b/src/ocrd_network/database.py @@ -2,7 +2,7 @@ Jobs: for every process-request a job is inserted into the database with an uuid, status and information about the process like parameters and file groups. It is mainly used to track the status -(`ocrd_network.models.job.StateEnum`) of a job so that the state of a job can be queried. Finished +(`ocrd_network.constants.JobState`) of a job so that the state of a job can be queried. Finished jobs are not deleted from the database. Workspaces: A job or a processor always runs on a workspace. 
So a processor needs the information @@ -15,16 +15,13 @@ from beanie import init_beanie from beanie.operators import In from motor.motor_asyncio import AsyncIOMotorClient -from uuid import uuid4 from pathlib import Path +from pymongo import MongoClient, uri_parser as mongo_uri_parser +from re import sub as re_sub from typing import List +from uuid import uuid4 -from .models import ( - DBProcessorJob, - DBWorkflowJob, - DBWorkspace, - DBWorkflowScript, -) +from .models import DBProcessorJob, DBWorkflowJob, DBWorkspace, DBWorkflowScript from .utils import call_sync @@ -94,15 +91,11 @@ async def db_update_workspace(workspace_id: str = None, workspace_mets_path: str if not workspace_id and not workspace_mets_path: raise ValueError(f'Either `workspace_id` or `workspace_mets_path` field must be used as a search key') if workspace_id: - workspace = await DBWorkspace.find_one( - DBWorkspace.workspace_id == workspace_id - ) + workspace = await DBWorkspace.find_one(DBWorkspace.workspace_id == workspace_id) if not workspace: raise ValueError(f'Workspace with id "{workspace_id}" not in the DB.') if workspace_mets_path: - workspace = await DBWorkspace.find_one( - DBWorkspace.workspace_mets_path == workspace_mets_path - ) + workspace = await DBWorkspace.find_one(DBWorkspace.workspace_mets_path == workspace_mets_path) if not workspace: raise ValueError(f'Workspace with path "{workspace_mets_path}" not in the DB.') @@ -215,13 +208,13 @@ async def sync_db_get_workflow_job(job_id: str) -> DBWorkflowJob: return await db_get_workflow_job(job_id) -async def db_get_processing_jobs(job_ids: List[str]) -> [DBProcessorJob]: +async def db_get_processing_jobs(job_ids: List[str]) -> List[DBProcessorJob]: jobs = await DBProcessorJob.find(In(DBProcessorJob.job_id, job_ids)).to_list() return jobs @call_sync -async def sync_db_get_processing_jobs(job_ids: List[str]) -> [DBProcessorJob]: +async def sync_db_get_processing_jobs(job_ids: List[str]) -> List[DBProcessorJob]: return await db_get_processing_jobs(job_ids) @@ -257,3 +250,28 @@ async def db_find_first_workflow_script_by_content(content_hash: str) -> DBWorkf @call_sync async def sync_db_find_first_workflow_script_by_content(workflow_id: str) -> DBWorkflowScript: return await db_get_workflow_script(workflow_id) + + +def verify_database_uri(mongodb_address: str) -> str: + try: + # perform validation check + mongo_uri_parser.parse_uri(uri=mongodb_address, validate=True) + except Exception as error: + raise ValueError(f"The MongoDB address '{mongodb_address}' is in wrong format, {error}") + return mongodb_address + + +def verify_mongodb_available(mongo_url: str) -> None: + """ + # The protocol is intentionally set to HTTP instead of MONGODB! + mongodb_test_url = mongo_url.replace("mongodb", "http") + if is_url_responsive(url=mongodb_test_url, tries=3): + return + raise RuntimeError(f"Verifying connection has failed: {mongodb_test_url}") + """ + + try: + client = MongoClient(mongo_url, serverSelectionTimeoutMS=60000.0) + client.admin.command("ismaster") + except Exception: + raise RuntimeError(f'Cannot connect to MongoDB: {re_sub(r":[^@]+@", ":****@", mongo_url)}') \ No newline at end of file diff --git a/src/ocrd_network/deployer.py b/src/ocrd_network/deployer.py deleted file mode 100644 index ff54e0578..000000000 --- a/src/ocrd_network/deployer.py +++ /dev/null @@ -1,568 +0,0 @@ -""" -Abstraction of the deployment functionality for processors. - -The Processing Server provides the configuration parameters to the Deployer agent. 
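For reference, the string-based enums introduced in `constants.py` above replace the removed `NETWORK_AGENT_*` constants and the old `StateEnum`, and they still compare equal to the plain string literals, which is what the updated decorator code relies on. A small sketch, assuming this branch of `ocrd_network` is installed:

# Sketch: str-mixin enums keep behaving like the old literal strings,
# so checks such as `subcommand == AgentType.PROCESSING_WORKER` accept "worker".
from ocrd_network import AgentType, JobState

assert AgentType.PROCESSING_WORKER == "worker"
assert AgentType.PROCESSOR_SERVER == "server"
assert JobState.success == "SUCCESS"  # member names are still lowercase, see the TODO in constants.py

def describe(state: JobState) -> str:
    # JobState values serialize as their string value, e.g. in database documents
    return f"job is {state.value}"

print(describe(JobState.queued))  # -> "job is QUEUED"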
-The Deployer agent runs the RabbitMQ Server, MongoDB and the Processing Hosts. -Each Processing Host may have several Processing Workers. -Each Processing Worker is an instance of an OCR-D processor. -""" -from __future__ import annotations -from typing import Dict, List, Union -from re import search as re_search -from pathlib import Path -import subprocess -from time import sleep - -from ocrd_utils import config, getLogger, safe_filename - -from .constants import NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER -from .deployment_utils import ( - create_docker_client, - DeployType, - verify_mongodb_available, - verify_rabbitmq_available, -) -from .logging import get_mets_server_logging_file_path -from .runtime_data import ( - DataHost, - DataMongoDB, - DataProcessingWorker, - DataProcessorServer, - DataRabbitMQ -) -from .utils import ( - is_mets_server_running, - stop_mets_server, - validate_and_load_config -) - - -class Deployer: - def __init__(self, config_path: str) -> None: - self.log = getLogger('ocrd_network.deployer') - config = validate_and_load_config(config_path) - - self.data_mongo: DataMongoDB = DataMongoDB(config['database']) - self.data_queue: DataRabbitMQ = DataRabbitMQ(config['process_queue']) - self.data_hosts: List[DataHost] = [] - self.internal_callback_url = config.get('internal_callback_url', None) - for config_host in config['hosts']: - self.data_hosts.append(DataHost(config_host)) - self.mets_servers: Dict = {} # {"mets_server_url": "mets_server_pid"} - - # TODO: Reconsider this. - def find_matching_processors( - self, - worker_only: bool = False, - server_only: bool = False, - docker_only: bool = False, - native_only: bool = False, - str_names_only: bool = False, - unique_only: bool = False - ) -> Union[List[str], List[object]]: - """Finds and returns a list of matching data objects of type: - `DataProcessingWorker` and `DataProcessorServer`. 
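A usage sketch for the database helpers added to `ocrd_network.database` above; the MongoDB URL and credentials are placeholders, and a reachable server is assumed for the availability check:

# Sketch: validating and probing a MongoDB address with the new helpers.
from ocrd_network.database import verify_database_uri, verify_mongodb_available

mongodb_url = "mongodb://admin:admin@localhost:27017"  # hypothetical credentials/host

verify_database_uri(mongodb_url)            # raises ValueError on a malformed URI
try:
    verify_mongodb_available(mongodb_url)   # raises RuntimeError if no server answers
except RuntimeError as error:
    print(error)                            # credentials are masked in the message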
- - :py:attr:`worker_only` match only processors with worker status - :py:attr:`server_only` match only processors with server status - :py:attr:`docker_only` match only docker processors - :py:attr:`native_only` match only native processors - :py:attr:`str_only` returns the processor_name instead of data object - :py:attr:`unique_only` remove duplicates from the matches - - `worker_only` and `server_only` are mutually exclusive to each other - `docker_only` and `native_only` are mutually exclusive to each other - `unique_only` is allowed only together with `str_names_only` - """ - - if worker_only and server_only: - raise ValueError(f"Only 'worker_only' or 'server_only' is allowed, not both.") - if docker_only and native_only: - raise ValueError(f"Only 'docker_only' or 'native_only' is allowed, not both.") - if not str_names_only and unique_only: - raise ValueError(f"Value 'unique_only' is allowed only together with 'str_names_only'") - - # Find all matching objects of type: - # DataProcessingWorker or DataProcessorServer - matched_objects = [] - for data_host in self.data_hosts: - if not server_only: - for data_worker in data_host.data_workers: - if data_worker.deploy_type == DeployType.NATIVE and docker_only: - continue - if data_worker.deploy_type == DeployType.DOCKER and native_only: - continue - matched_objects.append(data_worker) - if not worker_only: - for data_server in data_host.data_servers: - if data_server.deploy_type == DeployType.NATIVE and docker_only: - continue - if data_server.deploy_type == DeployType.DOCKER and native_only: - continue - matched_objects.append(data_server) - if str_names_only: - # gets only the processor names of the matched objects - name_list = [match.processor_name for match in matched_objects] - if unique_only: - # removes the duplicates, if any - return list(dict.fromkeys(name_list)) - return name_list - return matched_objects - - def resolve_processor_server_url(self, processor_name) -> str: - processor_server_url = '' - for data_host in self.data_hosts: - for data_server in data_host.data_servers: - if data_server.processor_name == processor_name: - processor_server_url = f'http://{data_host.address}:{data_server.port}/' - return processor_server_url - - def kill_all(self) -> None: - """ kill all started services: hosts, database, queue - - The order of killing is important to optimize graceful shutdown in the future. 
If RabbitMQ - server is killed before killing Processing Workers, that may have bad outcome and leave - Processing Workers in an unpredictable state - """ - self.kill_hosts() - self.kill_mongodb() - self.kill_rabbitmq() - - def deploy_hosts( - self, - mongodb_url: str, - rabbitmq_url: str - ) -> None: - for host_data in self.data_hosts: - if host_data.needs_ssh: - host_data.create_client(client_type='ssh') - assert host_data.ssh_client - if host_data.needs_docker: - host_data.create_client(client_type='docker') - assert host_data.docker_client - - self.log.debug(f'Deploying processing workers on host: {host_data.address}') - for data_worker in host_data.data_workers: - self._deploy_processing_worker( - mongodb_url, - rabbitmq_url, - host_data, - data_worker - ) - - self.log.debug(f'Deploying processor servers on host: {host_data.address}') - for data_server in host_data.data_servers: - self._deploy_processor_server( - mongodb_url, - host_data, - data_server - ) - - if host_data.ssh_client: - host_data.ssh_client.close() - host_data.ssh_client = None - if host_data.docker_client: - host_data.docker_client.close() - host_data.docker_client = None - - def _deploy_processing_worker( - self, - mongodb_url: str, - rabbitmq_url: str, - host_data: DataHost, - data_worker: DataProcessingWorker - ) -> None: - self.log.debug(f"Deploying processing worker, " - f"environment: '{data_worker.deploy_type}', " - f"name: '{data_worker.processor_name}', " - f"address: '{host_data.address}'") - - if data_worker.deploy_type == DeployType.NATIVE: - assert host_data.ssh_client # to satisfy mypy - pid = self.start_native_processor( - ssh_client=host_data.ssh_client, - processor_name=data_worker.processor_name, - queue_url=rabbitmq_url, - database_url=mongodb_url, - ) - data_worker.pid = pid - elif data_worker.deploy_type == DeployType.DOCKER: - assert host_data.docker_client # to satisfy mypy - pid = self.start_docker_processor( - docker_client=host_data.docker_client, - processor_name=data_worker.processor_name, - _queue_url=rabbitmq_url, - _database_url=mongodb_url - ) - data_worker.pid = pid - sleep(0.2) - - # TODO: Revisit this to remove code duplications of deploy_* methods - def _deploy_processor_server( - self, - mongodb_url: str, - host_data: DataHost, - data_server: DataProcessorServer, - ) -> None: - self.log.debug(f"Deploying processing worker, " - f"environment: '{data_server.deploy_type}', " - f"name: '{data_server.processor_name}', " - f"address: '{data_server.host}:{data_server.port}'") - - if data_server.deploy_type == DeployType.NATIVE: - assert host_data.ssh_client - pid = self.start_native_processor_server( - ssh_client=host_data.ssh_client, - processor_name=data_server.processor_name, - agent_address=f'{data_server.host}:{data_server.port}', - database_url=mongodb_url, - ) - data_server.pid = pid - - if data_server.processor_name in host_data.server_ports: - name = data_server.processor_name - port = data_server.port - if host_data.server_ports[name]: - host_data.server_ports[name] = host_data.server_ports[name].append(port) - else: - host_data.server_ports[name] = [port] - else: - host_data.server_ports[data_server.processor_name] = [data_server.port] - elif data_server.deploy_type == DeployType.DOCKER: - raise Exception("Deploying docker processor server is not supported yet!") - - def deploy_rabbitmq( - self, - image: str, - detach: bool, - remove: bool, - ports_mapping: Union[Dict, None] = None - ) -> str: - if self.data_queue.skip_deployment: - self.log.debug(f"RabbitMQ is externaly 
managed. Skipping deployment") - verify_rabbitmq_available( - self.data_queue.address, - self.data_queue.port, - self.data_queue.vhost, - self.data_queue.username, - self.data_queue.password - ) - return self.data_queue.url - self.log.debug(f"Trying to deploy '{image}', with modes: " - f"detach='{detach}', remove='{remove}'") - - if not self.data_queue or not self.data_queue.address: - raise ValueError('Deploying RabbitMQ has failed - missing configuration.') - - client = create_docker_client( - self.data_queue.address, - self.data_queue.ssh_username, - self.data_queue.ssh_password, - self.data_queue.ssh_keypath - ) - if not ports_mapping: - # 5672, 5671 - used by AMQP 0-9-1 and AMQP 1.0 clients without and with TLS - # 15672, 15671: HTTP API clients, management UI and rabbitmq admin, without and with TLS - # 25672: used for internode and CLI tools communication and is allocated from - # a dynamic range (limited to a single port by default, computed as AMQP port + 20000) - ports_mapping = { - 5672: self.data_queue.port, - 15672: 15672, - 25672: 25672 - } - res = client.containers.run( - image=image, - detach=detach, - remove=remove, - ports=ports_mapping, - # The default credentials to be used by the processing workers - environment=[ - f'RABBITMQ_DEFAULT_USER={self.data_queue.username}', - f'RABBITMQ_DEFAULT_PASS={self.data_queue.password}' - ] - ) - assert res and res.id, \ - f'Failed to start RabbitMQ docker container on host: {self.data_queue.address}' - self.data_queue.pid = res.id - client.close() - - rmq_host = self.data_queue.address - rmq_port = int(self.data_queue.port) - rmq_vhost = '/' - - verify_rabbitmq_available( - host=rmq_host, - port=rmq_port, - vhost=rmq_vhost, - username=self.data_queue.username, - password=self.data_queue.password - ) - self.log.info(f'The RabbitMQ server was deployed on URL: ' - f'{rmq_host}:{rmq_port}{rmq_vhost}') - return self.data_queue.url - - def deploy_mongodb( - self, - image: str, - detach: bool, - remove: bool, - ports_mapping: Union[Dict, None] = None - ) -> str: - if self.data_mongo.skip_deployment: - self.log.debug('MongoDB is externaly managed. 
Skipping deployment') - verify_mongodb_available(self.data_mongo.url) - return self.data_mongo.url - - self.log.debug(f"Trying to deploy '{image}', with modes: " - f"detach='{detach}', remove='{remove}'") - - if not self.data_mongo or not self.data_mongo.address: - raise ValueError('Deploying MongoDB has failed - missing configuration.') - - client = create_docker_client( - self.data_mongo.address, - self.data_mongo.ssh_username, - self.data_mongo.ssh_password, - self.data_mongo.ssh_keypath - ) - if not ports_mapping: - ports_mapping = { - 27017: self.data_mongo.port - } - if self.data_mongo.username: - environment = [ - f'MONGO_INITDB_ROOT_USERNAME={self.data_mongo.username}', - f'MONGO_INITDB_ROOT_PASSWORD={self.data_mongo.password}' - ] - else: - environment = [] - - res = client.containers.run( - image=image, - detach=detach, - remove=remove, - ports=ports_mapping, - environment=environment - ) - if not res or not res.id: - raise RuntimeError('Failed to start MongoDB docker container on host: ' - f'{self.data_mongo.address}') - self.data_mongo.pid = res.id - client.close() - - mongodb_hostinfo = f'{self.data_mongo.address}:{self.data_mongo.port}' - self.log.info(f'The MongoDB was deployed on host: {mongodb_hostinfo}') - return self.data_mongo.url - - def kill_rabbitmq(self) -> None: - if self.data_queue.skip_deployment: - return - elif not self.data_queue.pid: - self.log.warning('No running RabbitMQ instance found') - return - client = create_docker_client( - self.data_queue.address, - self.data_queue.ssh_username, - self.data_queue.ssh_password, - self.data_queue.ssh_keypath - ) - client.containers.get(self.data_queue.pid).stop() - self.data_queue.pid = None - client.close() - self.log.info('The RabbitMQ is stopped') - - def kill_mongodb(self) -> None: - if self.data_mongo.skip_deployment: - return - elif not self.data_mongo.pid: - self.log.warning('No running MongoDB instance found') - return - client = create_docker_client( - self.data_mongo.address, - self.data_mongo.ssh_username, - self.data_mongo.ssh_password, - self.data_mongo.ssh_keypath - ) - client.containers.get(self.data_mongo.pid).stop() - self.data_mongo.pid = None - client.close() - self.log.info('The MongoDB is stopped') - - def kill_hosts(self) -> None: - self.log.debug('Starting to kill/stop hosts') - # Kill processing hosts - for host_data in self.data_hosts: - if host_data.needs_ssh: - host_data.create_client(client_type='ssh') - assert host_data.ssh_client - if host_data.needs_docker: - host_data.create_client(client_type='docker') - assert host_data.docker_client - - self.log.debug(f'Killing/Stopping processing workers on host: {host_data.address}') - self.kill_processing_workers(host_data) - - self.log.debug(f'Killing/Stopping processor servers on host: {host_data.address}') - self.kill_processor_servers(host_data) - - if host_data.ssh_client: - host_data.ssh_client.close() - host_data.ssh_client = None - if host_data.docker_client: - host_data.docker_client.close() - host_data.docker_client = None - - # TODO: Optimize the code duplication from start_* and kill_* methods - def kill_processing_workers(self, host_data: DataHost) -> None: - amount = len(host_data.data_workers) - if not amount: - self.log.info(f'No active processing workers to be stopped.') - return - self.log.info(f"Trying to stop {amount} processing workers:") - for worker in host_data.data_workers: - if not worker.pid: - continue - if worker.deploy_type == DeployType.NATIVE: - host_data.ssh_client.exec_command(f'kill {worker.pid}') - 
self.log.info(f"Stopped native worker with pid: '{worker.pid}'") - elif worker.deploy_type == DeployType.DOCKER: - host_data.docker_client.containers.get(worker.pid).stop() - self.log.info(f"Stopped docker worker with container id: '{worker.pid}'") - host_data.data_workers = [] - - def kill_processor_servers(self, host_data: DataHost) -> None: - amount = len(host_data.data_servers) - if not amount: - self.log.info(f'No active processor servers to be stopped.') - return - self.log.info(f"Trying to stop {amount} processing workers:") - for server in host_data.data_servers: - if not server.pid: - continue - if server.deploy_type == DeployType.NATIVE: - host_data.ssh_client.exec_command(f'kill {server.pid}') - self.log.info(f"Stopped native server with pid: '{server.pid}'") - elif server.deploy_type == DeployType.DOCKER: - host_data.docker_client.containers.get(server.pid).stop() - self.log.info(f"Stopped docker server with container id: '{server.pid}'") - host_data.data_servers = [] - - def start_native_processor( - self, - ssh_client, - processor_name: str, - queue_url: str, - database_url: str - ) -> str: - """ start a processor natively on a host via ssh - - Args: - ssh_client: paramiko SSHClient to execute commands on a host - processor_name: name of processor to run - queue_url: url to rabbitmq - database_url: url to database - - Returns: - str: pid of running process - """ - self.log.info(f'Starting native processing worker: {processor_name}') - channel = ssh_client.invoke_shell() - stdin, stdout = channel.makefile('wb'), channel.makefile('rb') - cmd = f'{processor_name} {NETWORK_AGENT_WORKER} --database {database_url} --queue {queue_url} &' - # the only way (I could find) to make it work to start a process in the background and - # return early is this construction. The pid of the last started background process is - # printed with `echo $!` but it is printed inbetween other output. Because of that I added - # `xyz` before and after the code to easily be able to filter out the pid via regex when - # returning from the function - - self.log.debug(f'About to execute command: {cmd}') - stdin.write(f'{cmd}\n') - stdin.write('echo xyz$!xyz \n exit \n') - output = stdout.read().decode('utf-8') - stdout.close() - stdin.close() - return re_search(r'xyz([0-9]+)xyz', output).group(1) # type: ignore - - def start_docker_processor( - self, - docker_client, - processor_name: str, - _queue_url: str, - _database_url: str - ) -> str: - # TODO: Raise an exception here as well? - # raise Exception("Deploying docker processing worker is not supported yet!") - - self.log.info(f'Starting docker container processor: {processor_name}') - # TODO: add real command here to start processing server in docker here - res = docker_client.containers.run('debian', 'sleep 500s', detach=True, remove=True) - assert res and res.id, f'Running processor: {processor_name} in docker-container failed' - return res.id - - # TODO: Just a copy of the above start_native_processor() method. - # Far from being great... 
But should be good as a starting point - def start_native_processor_server( - self, - ssh_client, - processor_name: str, - agent_address: str, - database_url: str - ) -> str: - self.log.info(f"Starting native processor server: {processor_name} on {agent_address}") - channel = ssh_client.invoke_shell() - stdin, stdout = channel.makefile('wb'), channel.makefile('rb') - cmd = f'{processor_name} {NETWORK_AGENT_SERVER} --address {agent_address} --database {database_url} &' - self.log.debug(f'About to execute command: {cmd}') - stdin.write(f'{cmd}\n') - stdin.write('echo xyz$!xyz \n exit \n') - output = stdout.read().decode('utf-8') - stdout.close() - stdin.close() - return re_search(r'xyz([0-9]+)xyz', output).group(1) # type: ignore - - # TODO: No support for TCP version yet - def start_unix_mets_server(self, mets_path: str) -> Path: - log_file = get_mets_server_logging_file_path(mets_path=mets_path) - mets_server_url = Path(config.OCRD_NETWORK_SOCKETS_ROOT_DIR, f"{safe_filename(mets_path)}.sock") - - if is_mets_server_running(mets_server_url=str(mets_server_url)): - self.log.info(f"The mets server is already started: {mets_server_url}") - return mets_server_url - - cwd = Path(mets_path).parent - self.log.info(f'Starting UDS mets server: {mets_server_url}') - sub_process = subprocess.Popen( - args=['nohup', 'ocrd', 'workspace', '--mets-server-url', f'{mets_server_url}', - '-d', f'{cwd}', 'server', 'start'], - shell=False, - stdout=open(file=log_file, mode='w'), - stderr=open(file=log_file, mode='a'), - cwd=cwd, - universal_newlines=True - ) - # Wait for the mets server to start - sleep(2) - self.mets_servers[mets_server_url] = sub_process.pid - return mets_server_url - - def stop_unix_mets_server(self, mets_server_url: str) -> None: - self.log.info(f'Stopping UDS mets server: {mets_server_url}') - if Path(mets_server_url) in self.mets_servers: - mets_server_pid = self.mets_servers[Path(mets_server_url)] - else: - raise Exception(f"Mets server not found: {mets_server_url}") - - ''' - subprocess.run( - args=['kill', '-s', 'SIGINT', f'{mets_server_pid}'], - shell=False, - universal_newlines=True - ) - ''' - - # TODO: Reconsider this again - # Not having this sleep here causes connection errors - # on the last request processed by the processing worker. - # Sometimes 3 seconds is enough, sometimes not. 
- sleep(5) - stop_mets_server(mets_server_url=mets_server_url) diff --git a/src/ocrd_network/logging.py b/src/ocrd_network/logging.py deleted file mode 100644 index 3365e2ddc..000000000 --- a/src/ocrd_network/logging.py +++ /dev/null @@ -1,48 +0,0 @@ -from pathlib import Path -from ocrd_utils import safe_filename, config - -from .constants import NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER - -OCRD_NETWORK_MODULES = [ - "mets_servers", - "processing_jobs", - "processing_servers", - "processing_workers", - "processor_servers" -] - - -def get_root_logging_dir(module_name: str) -> Path: - if module_name not in OCRD_NETWORK_MODULES: - raise ValueError(f"Invalid module name: {module_name}, should be one of: {OCRD_NETWORK_MODULES}") - module_log_dir = Path(config.OCRD_NETWORK_LOGS_ROOT_DIR, module_name) - module_log_dir.mkdir(parents=True, exist_ok=True) - return module_log_dir - - -def get_cache_locked_pages_logging_file_path() -> Path: - return get_root_logging_dir("processing_servers") / "cache_locked_pages.log" - - -def get_cache_processing_requests_logging_file_path() -> Path: - return get_root_logging_dir("processing_servers") / "cache_processing_requests.log" - - -def get_processing_job_logging_file_path(job_id: str) -> Path: - return get_root_logging_dir("processing_jobs") / f"{job_id}.log" - - -def get_processing_server_logging_file_path(pid: int) -> Path: - return get_root_logging_dir("processing_servers") / f"processing_server.{pid}.log" - - -def get_processing_worker_logging_file_path(processor_name: str, pid: int) -> Path: - return get_root_logging_dir("processing_workers") / f"{NETWORK_AGENT_WORKER}.{pid}.{processor_name}.log" - - -def get_processor_server_logging_file_path(processor_name: str, pid: int) -> Path: - return get_root_logging_dir("processor_servers") / f"{NETWORK_AGENT_SERVER}.{pid}.{processor_name}.log" - - -def get_mets_server_logging_file_path(mets_path: str) -> Path: - return get_root_logging_dir("mets_servers") / f"{safe_filename(mets_path)}.log" diff --git a/src/ocrd_network/logging_utils.py b/src/ocrd_network/logging_utils.py new file mode 100644 index 000000000..c20136642 --- /dev/null +++ b/src/ocrd_network/logging_utils.py @@ -0,0 +1,52 @@ +from logging import FileHandler, Formatter, Logger +from pathlib import Path + +from ocrd_utils import config, LOG_FORMAT, safe_filename +from .constants import AgentType, NetworkLoggingDirs + + +def configure_file_handler_with_formatter(logger: Logger, log_file: Path, mode: str = "a") -> None: + file_handler = FileHandler(filename=log_file, mode=mode) + file_handler.setFormatter(Formatter(LOG_FORMAT)) + logger.addHandler(file_handler) + + +def get_root_logging_dir(module_name: NetworkLoggingDirs) -> Path: + module_log_dir = Path(config.OCRD_NETWORK_LOGS_ROOT_DIR, module_name.value) + module_log_dir.mkdir(parents=True, exist_ok=True) + return module_log_dir + + +def get_cache_locked_pages_logging_file_path() -> Path: + log_file: str = "cache_locked_pages.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSING_SERVERS), log_file) + + +def get_cache_processing_requests_logging_file_path() -> Path: + log_file: str = "cache_processing_requests.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSING_SERVERS), log_file) + + +def get_mets_server_logging_file_path(mets_path: str) -> Path: + log_file: str = f"{safe_filename(mets_path)}.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.METS_SERVERS), log_file) + + +def get_processing_job_logging_file_path(job_id: str) -> Path: + log_file: str = 
f"{job_id}.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSING_JOBS), log_file) + + +def get_processing_server_logging_file_path(pid: int) -> Path: + log_file: str = f"processing_server.{pid}.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSING_SERVERS), log_file) + + +def get_processing_worker_logging_file_path(processor_name: str, pid: int) -> Path: + log_file: str = f"{AgentType.PROCESSING_WORKER}.{pid}.{processor_name}.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSING_WORKERS), log_file) + + +def get_processor_server_logging_file_path(processor_name: str, pid: int) -> Path: + log_file: str = f"{AgentType.PROCESSOR_SERVER}.{pid}.{processor_name}.log" + return Path(get_root_logging_dir(NetworkLoggingDirs.PROCESSOR_SERVERS), log_file) diff --git a/src/ocrd_network/models/__init__.py b/src/ocrd_network/models/__init__.py index dc8231f76..774f8aa13 100644 --- a/src/ocrd_network/models/__init__.py +++ b/src/ocrd_network/models/__init__.py @@ -12,18 +12,10 @@ 'PYJobOutput', 'PYOcrdTool', 'PYResultMessage', - 'PYWorkflowJobOutput', - 'StateEnum', + 'PYWorkflowJobOutput' ] -from .job import ( - DBProcessorJob, - DBWorkflowJob, - PYJobInput, - PYJobOutput, - PYWorkflowJobOutput, - StateEnum -) +from .job import DBProcessorJob, DBWorkflowJob, PYJobInput, PYJobOutput, PYWorkflowJobOutput from .messages import PYResultMessage from .ocrd_tool import PYOcrdTool from .workspace import DBWorkspace diff --git a/src/ocrd_network/models/job.py b/src/ocrd_network/models/job.py index 6cb31bfb9..efc6750c4 100644 --- a/src/ocrd_network/models/job.py +++ b/src/ocrd_network/models/job.py @@ -1,27 +1,8 @@ from beanie import Document from datetime import datetime -from enum import Enum from pydantic import BaseModel from typing import Dict, List, Optional - -from ocrd_network import NETWORK_AGENT_WORKER - - -class StateEnum(str, Enum): - # The processing job is cached inside the Processing Server requests cache - cached = 'CACHED' - # The processing job was cancelled due to failed dependencies - cancelled = 'CANCELLED' - # The processing job is queued inside the RabbitMQ - queued = 'QUEUED' - # Processing job is currently running in a Worker or Processor Server - running = 'RUNNING' - # Processing job finished successfully - success = 'SUCCESS' - # Processing job failed - failed = 'FAILED' - # Processing job has not been assigned yet - unset = 'UNSET' +from ..constants import AgentType, JobState class PYJobInput(BaseModel): @@ -37,9 +18,8 @@ class PYJobInput(BaseModel): parameters: dict = {} # Always set to empty dict when None, otherwise it fails ocr-d-validation result_queue_name: Optional[str] = None callback_url: Optional[str] = None - # Used to toggle between sending requests to 'worker' and 'server', - # i.e., Processing Worker and Processor Server, respectively - agent_type: Optional[str] = NETWORK_AGENT_WORKER + # Used to toggle between sending requests to different network agents + agent_type: AgentType = AgentType.PROCESSING_WORKER # Auto generated by the Processing Server when forwarding to the Processor Server job_id: Optional[str] = None # If set, specifies a list of job ids this job depends on @@ -50,9 +30,10 @@ class Config: 'example': { 'path_to_mets': '/path/to/mets.xml', 'description': 'The description of this execution', - 'input_file_grps': ['INPUT_FILE_GROUP'], - 'output_file_grps': ['OUTPUT_FILE_GROUP'], - 'page_id': 'PAGE_ID', + 'input_file_grps': ['DEFAULT'], + 'output_file_grps': ['OCR-D-BIN'], + 'agent_type': 
AgentType.PROCESSING_WORKER, + 'page_id': 'PHYS_0001..PHYS_0003', 'parameters': {} } } @@ -63,7 +44,7 @@ class PYJobOutput(BaseModel): """ job_id: str processor_name: str - state: StateEnum + state: JobState = JobState.unset path_to_mets: Optional[str] workspace_id: Optional[str] input_file_grps: List[str] @@ -80,7 +61,7 @@ class DBProcessorJob(Document): path_to_mets: Optional[str] workspace_id: Optional[str] description: Optional[str] - state: StateEnum + state: JobState = JobState.unset input_file_grps: List[str] output_file_grps: Optional[List[str]] page_id: Optional[str] @@ -114,13 +95,13 @@ def to_job_output(self) -> PYJobOutput: class PYWorkflowJobOutput(BaseModel): """ Wraps output information for a workflow job-response """ - job_id: str - page_id: str - page_wise: bool = False # A dictionary where each entry has: # key: page_id # value: List of and processing job ids sorted in dependency order - processing_job_ids: Dict + processing_job_ids: Dict[str, List[str]] + page_id: str + page_wise: bool = False + job_id: str path_to_mets: Optional[str] workspace_id: Optional[str] description: Optional[str] diff --git a/src/ocrd_network/models/messages.py b/src/ocrd_network/models/messages.py index 062f1a9a2..de3e06687 100644 --- a/src/ocrd_network/models/messages.py +++ b/src/ocrd_network/models/messages.py @@ -1,22 +1,22 @@ from pydantic import BaseModel from typing import Optional -from .job import StateEnum +from .job import JobState class PYResultMessage(BaseModel): """ Wraps the parameters required to make a result message request """ job_id: str - state: StateEnum + state: JobState = JobState.unset path_to_mets: Optional[str] = None workspace_id: Optional[str] = None class Config: schema_extra = { - 'example': { - 'job_id': '123123123', - 'state': 'SUCCESS', - 'path_to_mets': '/path/to/mets.xml', - 'workspace_id': 'c7f25615-fc17-4365-a74d-ad20e1ddbd0e' + "example": { + "job_id": "d8e36726-ed28-5476-b83c-bc31d2eecf1f", + "state": JobState.success, + "path_to_mets": "/path/to/mets.xml", + "workspace_id": "c7f25615-fc17-4365-a74d-ad20e1ddbd0e" } } diff --git a/src/ocrd_network/param_validators.py b/src/ocrd_network/param_validators.py index 87cfeee72..27658c048 100644 --- a/src/ocrd_network/param_validators.py +++ b/src/ocrd_network/param_validators.py @@ -1,45 +1,46 @@ from click import ParamType -from .utils import ( - verify_database_uri, - verify_and_parse_mq_uri -) +from .database import verify_database_uri +from .rabbitmq_utils import verify_and_parse_mq_uri class ServerAddressParamType(ParamType): - name = 'Server address string format' - expected_format = 'host:port' + name = "Server address string format" + expected_format = "host:port" def convert(self, value, param, ctx): try: elements = value.split(':') if len(elements) != 2: - raise ValueError('The processing server address is in wrong format') + raise ValueError("The processing server address is in wrong format") int(elements[1]) # validate port except ValueError as error: - self.fail(f'{error}, expected format: {self.expected_format}', param, ctx) + message = f"Expected format: {self.expected_format}, error: {error}" + self.fail(message, param, ctx) return value class QueueServerParamType(ParamType): - name = 'Message queue server string format' + name = "Message queue server string format" def convert(self, value, param, ctx): try: # perform validation check only verify_and_parse_mq_uri(value) except Exception as error: - self.fail(f'{error}', param, ctx) + message = f"Failed to validate the RabbitMQ address, error: 
{error}" + self.fail(message, param, ctx) return value class DatabaseParamType(ParamType): - name = 'Database string format' + name = "Database string format" def convert(self, value, param, ctx): try: # perform validation check only verify_database_uri(value) except Exception as error: - self.fail(f'{error}', param, ctx) + message = f"Failed to validate the MongoDB address, error: {error}" + self.fail(message, param, ctx) return value diff --git a/src/ocrd_network/process_helpers.py b/src/ocrd_network/process_helpers.py index 04dcd17d0..590ed13f4 100644 --- a/src/ocrd_network/process_helpers.py +++ b/src/ocrd_network/process_helpers.py @@ -1,35 +1,32 @@ from contextlib import nullcontext -import json +from json import dumps +from pathlib import Path from typing import List, Optional from ocrd.processor.helpers import run_cli, run_processor from ocrd_utils import redirect_stderr_and_stdout_to_file, initLogging - from .utils import get_ocrd_workspace_instance # A wrapper for run_processor() and run_cli() def invoke_processor( - processor_class, - executable: str, - abs_path_to_mets: str, - input_file_grps: List[str], - output_file_grps: List[str], - page_id: str, - parameters: dict, - mets_server_url: Optional[str] = None, - log_filename: str = None, + processor_class, + executable: str, + abs_path_to_mets: str, + input_file_grps: List[str], + output_file_grps: List[str], + page_id: str, + parameters: dict, + mets_server_url: Optional[str] = None, + log_filename: Optional[Path] = None, + log_level: str = "DEBUG" ) -> None: if not (processor_class or executable): - raise ValueError('Missing processor class and executable') + raise ValueError("Missing processor class and executable") input_file_grps_str = ','.join(input_file_grps) output_file_grps_str = ','.join(output_file_grps) - workspace = get_ocrd_workspace_instance( - mets_path=abs_path_to_mets, - mets_server_url=mets_server_url - ) - + workspace = get_ocrd_workspace_instance(mets_path=abs_path_to_mets, mets_server_url=mets_server_url) if processor_class: ctx_mgr = redirect_stderr_and_stdout_to_file(log_filename) if log_filename else nullcontext() with ctx_mgr: @@ -44,10 +41,10 @@ def invoke_processor( parameter=parameters, instance_caching=True, mets_server_url=mets_server_url, - log_level='DEBUG' + log_level=log_level ) - except Exception as e: - raise RuntimeError(f"Python executable '{processor_class.__dict__}' exited with: {e}") + except Exception as error: + raise RuntimeError(f"Python executable '{processor_class.__dict__}', error: {error}") else: return_code = run_cli( executable=executable, @@ -56,10 +53,10 @@ def invoke_processor( input_file_grp=input_file_grps_str, output_file_grp=output_file_grps_str, page_id=page_id, - parameter=json.dumps(parameters), + parameter=dumps(parameters), mets_server_url=mets_server_url, - log_level='DEBUG', + log_level=log_level, log_filename=log_filename ) if return_code != 0: - raise RuntimeError(f"CLI executable '{executable}' exited with: {return_code}") + raise RuntimeError(f"CLI executable '{executable}' exited with code: {return_code}") diff --git a/src/ocrd_network/processing_server.py b/src/ocrd_network/processing_server.py index 3c09e8a10..36c838aeb 100644 --- a/src/ocrd_network/processing_server.py +++ b/src/ocrd_network/processing_server.py @@ -1,46 +1,26 @@ from datetime import datetime -from hashlib import md5 -import httpx -import json -from logging import FileHandler, Formatter from os import getpid -from pathlib import Path -import requests from typing import Dict, List, Union 
-from urllib.parse import urljoin -import uvicorn - -from fastapi import ( - FastAPI, - status, - Request, - HTTPException, - UploadFile, - File, -) +from uvicorn import run as uvicorn_run + +from fastapi import APIRouter, FastAPI, File, HTTPException, Request, status, UploadFile from fastapi.exceptions import RequestValidationError from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse -from pika.exceptions import ChannelClosedByBroker -from ocrd import Resolver, Workspace from ocrd.task_sequence import ProcessorTask -from ocrd_utils import initLogging, getLogger, LOG_FORMAT - -from .constants import NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER +from ocrd_utils import initLogging, getLogger +from .constants import AgentType, JobState, OCRD_ALL_JSON_TOOLS_URL, ServerApiTags from .database import ( initiate_database, - db_create_workspace, db_get_processing_job, db_get_processing_jobs, - db_get_workflow_job, - db_get_workspace, db_update_processing_job, db_update_workspace, db_get_workflow_script, db_find_first_workflow_script_by_content ) -from .deployer import Deployer -from .logging import get_processing_server_logging_file_path +from .runtime_data import Deployer +from .logging_utils import configure_file_handler_with_formatter, get_processing_server_logging_file_path from .models import ( DBProcessorJob, DBWorkflowJob, @@ -48,24 +28,39 @@ PYJobInput, PYJobOutput, PYResultMessage, - PYWorkflowJobOutput, - StateEnum + PYWorkflowJobOutput +) +from .rabbitmq_utils import ( + check_if_queue_exists, + connect_rabbitmq_publisher, + create_message_queues, + OcrdProcessingMessage ) -from .rabbitmq_utils import RMQPublisher, OcrdProcessingMessage from .server_cache import CacheLockedPages, CacheProcessingRequests from .server_utils import ( + create_processing_message, + create_workspace_if_not_exists, + forward_job_to_processor_server, _get_processor_job, _get_processor_job_log, - expand_page_ids, + get_page_ids_list, + get_workflow_content, + get_from_database_workspace, + get_from_database_workflow_job, + parse_workflow_tasks, + raise_http_exception, + request_processor_server_tool_json, validate_and_return_mets_path, - validate_job_input + validate_first_task_input_file_groups_existence, + validate_job_input, + validate_workflow ) from .utils import ( download_ocrd_all_tool_json, - generate_created_time, + expand_page_ids, generate_id, - get_ocrd_workspace_physical_pages, - validate_workflow, + generate_workflow_content, + generate_workflow_content_hash ) @@ -82,37 +77,43 @@ class ProcessingServer(FastAPI): def __init__(self, config_path: str, host: str, port: int) -> None: initLogging() + self.title = "OCR-D Processing Server" super().__init__( + title=self.title, on_startup=[self.on_startup], on_shutdown=[self.on_shutdown], - title='OCR-D Processing Server', - description='OCR-D Processing Server' + description="OCR-D Processing Server" ) - self.log = getLogger('ocrd_network.processing_server') + self.log = getLogger("ocrd_network.processing_server") log_file = get_processing_server_logging_file_path(pid=getpid()) - file_handler = FileHandler(filename=log_file, mode='a') - file_handler.setFormatter(Formatter(LOG_FORMAT)) - self.log.addHandler(file_handler) + configure_file_handler_with_formatter(self.log, log_file=log_file, mode="a") self.log.info(f"Downloading ocrd all tool json") - self.ocrd_all_tool_json = download_ocrd_all_tool_json( - ocrd_all_url="https://ocr-d.de/js/ocrd-all-tool.json" - ) + self.ocrd_all_tool_json = 
download_ocrd_all_tool_json(ocrd_all_url=OCRD_ALL_JSON_TOOLS_URL) self.hostname = host self.port = port # The deployer is used for: # - deploying agents when the Processing Server is started # - retrieving runtime data of agents self.deployer = Deployer(config_path) + # Used by processing workers and/or processor servers to report back the results + if self.deployer.internal_callback_url: + host = self.deployer.internal_callback_url + self.internal_job_callback_url = f"{host.rstrip('/')}/result_callback" + else: + self.internal_job_callback_url = f"http://{host}:{port}/result_callback" + self.mongodb_url = None - # TODO: Combine these under a single URL, rabbitmq_utils needs an update - self.rmq_host = self.deployer.data_queue.address - self.rmq_port = self.deployer.data_queue.port - self.rmq_vhost = '/' - self.rmq_username = self.deployer.data_queue.username - self.rmq_password = self.deployer.data_queue.password - - # Gets assigned when `connect_publisher` is called on the working object + self.rabbitmq_url = None + self.rmq_data = { + "host": self.deployer.data_queue.host, + "port": self.deployer.data_queue.port, + "vhost": "/", + "username": self.deployer.data_queue.cred_username, + "password": self.deployer.data_queue.cred_password + } + + # Gets assigned when `connect_rabbitmq_publisher()` is called on the working object self.rmq_publisher = None # Used for keeping track of cached processing requests @@ -121,190 +122,184 @@ def __init__(self, config_path: str, host: str, port: int) -> None: # Used for keeping track of locked/unlocked pages of a workspace self.cache_locked_pages = CacheLockedPages() - # Used by processing workers and/or processor servers to report back the results - if self.deployer.internal_callback_url: - host = self.deployer.internal_callback_url - self.internal_job_callback_url = f'{host.rstrip("/")}/result_callback' - else: - self.internal_job_callback_url = f'http://{host}:{port}/result_callback' + self.add_api_routes_others() + self.add_api_routes_processing() + self.add_api_routes_workflow() + + @self.exception_handler(RequestValidationError) + async def validation_exception_handler(request: Request, exc: RequestValidationError): + exc_str = f'{exc}'.replace('\n', ' ').replace(' ', ' ') + self.log.error(f'{request}: {exc_str}') + content = {'status_code': 10422, 'message': exc_str, 'data': None} + return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) + + def start(self) -> None: + """ deploy agents (db, queue, workers) and start the processing server with uvicorn + """ + try: + self.rabbitmq_url = self.deployer.deploy_rabbitmq() + self.mongodb_url = self.deployer.deploy_mongodb() + + # The RMQPublisher is initialized and a connection to the RabbitMQ is performed + self.rmq_publisher = connect_rabbitmq_publisher(self.log, self.rmq_data, enable_acks=True) + + queue_names = self.deployer.find_matching_network_agents( + worker_only=True, str_names_only=True, unique_only=True + ) + self.log.debug(f"Creating message queues on RabbitMQ instance url: {self.rabbitmq_url}") + create_message_queues(logger=self.log, rmq_publisher=self.rmq_publisher, queue_names=queue_names) - self.router.add_api_route( - path='/', + self.deployer.deploy_network_agents(mongodb_url=self.mongodb_url, rabbitmq_url=self.rabbitmq_url) + except Exception as error: + self.log.exception(f"Failed to start the Processing Server, error: {error}") + self.log.warning("Trying to stop previously deployed services and network agents.") + self.deployer.stop_all() + raise + 
uvicorn_run(self, host=self.hostname, port=int(self.port)) + + async def on_startup(self): + await initiate_database(db_url=self.mongodb_url) + + async def on_shutdown(self) -> None: + """ + - hosts and pids should be stored somewhere + - ensure queue is empty or processor is not currently running + - connect to hosts and kill pids + """ + await self.stop_deployed_agents() + + def add_api_routes_others(self): + others_router = APIRouter() + others_router.add_api_route( + path="/", endpoint=self.home_page, - methods=['GET'], + methods=["GET"], status_code=status.HTTP_200_OK, - summary='Get information about the processing server' + summary="Get information about the processing server" ) - # Create routes - self.router.add_api_route( - path='/stop', + others_router.add_api_route( + path="/stop", endpoint=self.stop_deployed_agents, - methods=['POST'], - tags=['tools'], - summary='Stop database, queue and processing-workers', + methods=["POST"], + tags=[ServerApiTags.TOOLS], + summary="Stop database, queue and processing-workers" ) + self.include_router(others_router) - self.router.add_api_route( - path='/processor/run/{processor_name}', - endpoint=self.push_processor_job, - methods=['POST'], - tags=['processing'], + def add_api_routes_processing(self): + processing_router = APIRouter() + processing_router.add_api_route( + path="/processor", + endpoint=self.list_processors, + methods=["GET"], + tags=[ServerApiTags.PROCESSING, ServerApiTags.DISCOVERY], + status_code=status.HTTP_200_OK, + summary="Get a list of all available processors" + ) + processing_router.add_api_route( + path="/processor/info/{processor_name}", + endpoint=self.get_network_agent_ocrd_tool, + methods=["GET"], + tags=[ServerApiTags.PROCESSING, ServerApiTags.DISCOVERY], status_code=status.HTTP_200_OK, - summary='Submit a job to this processor', + summary="Get information about this processor" + ) + processing_router.add_api_route( + path="/processor/run/{processor_name}", + endpoint=self.validate_and_forward_job_to_network_agent, + methods=["POST"], + tags=[ServerApiTags.PROCESSING], + status_code=status.HTTP_200_OK, + summary="Submit a job to this processor", response_model=PYJobOutput, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/processor/job/{job_id}', + processing_router.add_api_route( + path="/processor/job/{job_id}", endpoint=self.get_processor_job, - methods=['GET'], - tags=['processing'], + methods=["GET"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Get information about a job based on its ID', + summary="Get information about a job based on its ID", response_model=PYJobOutput, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/processor/log/{job_id}', + processing_router.add_api_route( + path="/processor/log/{job_id}", endpoint=self.get_processor_job_log, - methods=['GET'], - tags=['processing'], + methods=["GET"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Get the log file of a job id' + summary="Get the log file of a job id" ) - - self.router.add_api_route( - path='/result_callback', - endpoint=self.remove_from_request_cache, - methods=['POST'], - tags=['processing'], + processing_router.add_api_route( + path="/result_callback", + endpoint=self.remove_job_from_request_cache, + methods=["POST"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Callback used by a worker or processor server for 
reporting result of a processing request', + summary="Callback used by a worker or processor server for reporting result of a processing request" ) + self.include_router(processing_router) - self.router.add_api_route( - path='/processor/info/{processor_name}', - endpoint=self.get_processor_info, - methods=['GET'], - tags=['processing', 'discovery'], + def add_api_routes_workflow(self): + workflow_router = APIRouter() + workflow_router.add_api_route( + path="/workflow", + endpoint=self.upload_workflow, + methods=["POST"], + tags=[ServerApiTags.WORKFLOW], + status_code=status.HTTP_201_CREATED, + summary="Upload/Register a new workflow script" + ) + workflow_router.add_api_route( + path="/workflow/{workflow_id}", + endpoint=self.download_workflow, + methods=["GET"], + tags=[ServerApiTags.WORKFLOW], status_code=status.HTTP_200_OK, - summary='Get information about this processor', + summary="Download a workflow script" ) - - self.router.add_api_route( - path='/processor', - endpoint=self.list_processors, - methods=['GET'], - tags=['processing', 'discovery'], + workflow_router.add_api_route( + path="/workflow/{workflow_id}", + endpoint=self.replace_workflow, + methods=["PUT"], + tags=[ServerApiTags.WORKFLOW], status_code=status.HTTP_200_OK, - summary='Get a list of all available processors', + summary="Update/Replace a workflow script" ) - - self.router.add_api_route( - path='/workflow/run', + workflow_router.add_api_route( + path="/workflow/run", endpoint=self.run_workflow, - methods=['POST'], - tags=['workflow', 'processing'], + methods=["POST"], + tags=[ServerApiTags.WORKFLOW, ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Run a workflow', + summary="Run a workflow", response_model=PYWorkflowJobOutput, - response_model_exclude=["processing_job_ids"], response_model_exclude_defaults=True, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/workflow/job/{workflow_job_id}', - endpoint=self.get_workflow_info, - methods=['GET'], - tags=['workflow', 'processing'], - status_code=status.HTTP_200_OK, - summary='Get information about a workflow run', - ) - - self.router.add_api_route( - path='/workflow/job-simple/{workflow_job_id}', + workflow_router.add_api_route( + path="/workflow/job-simple/{workflow_job_id}", endpoint=self.get_workflow_info_simple, - methods=['GET'], - tags=['workflow', 'processing'], - status_code=status.HTTP_200_OK, - summary='Get simplified overall job status', - ) - - self.router.add_api_route( - path='/workflow', - endpoint=self.upload_workflow, - methods=['POST'], - tags=['workflow'], - status_code=status.HTTP_201_CREATED, - summary='Upload/Register a new workflow script', - ) - self.router.add_api_route( - path='/workflow/{workflow_id}', - endpoint=self.replace_workflow, - methods=['PUT'], - tags=['workflow'], + methods=["GET"], + tags=[ServerApiTags.WORKFLOW, ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Update/Replace a workflow script', + summary="Get simplified overall job status" ) - self.router.add_api_route( - path='/workflow/{workflow_id}', - endpoint=self.download_workflow, - methods=['GET'], - tags=['workflow'], + workflow_router.add_api_route( + path="/workflow/job/{workflow_job_id}", + endpoint=self.get_workflow_info, + methods=["GET"], + tags=[ServerApiTags.WORKFLOW, ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Download a workflow script', + summary="Get information about a workflow run" ) - - 
@self.exception_handler(RequestValidationError) - async def validation_exception_handler(request: Request, exc: RequestValidationError): - exc_str = f'{exc}'.replace('\n', ' ').replace(' ', ' ') - self.log.error(f'{request}: {exc_str}') - content = {'status_code': 10422, 'message': exc_str, 'data': None} - return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - - def start(self) -> None: - """ deploy agents (db, queue, workers) and start the processing server with uvicorn - """ - try: - self.deployer.deploy_rabbitmq(image='rabbitmq:3-management', detach=True, remove=True) - rabbitmq_url = self.deployer.data_queue.url - - self.deployer.deploy_mongodb(image='mongo', detach=True, remove=True) - self.mongodb_url = self.deployer.data_mongo.url - - # The RMQPublisher is initialized and a connection to the RabbitMQ is performed - self.connect_publisher() - self.log.debug(f'Creating message queues on RabbitMQ instance url: {rabbitmq_url}') - self.create_message_queues() - - self.deployer.deploy_hosts( - mongodb_url=self.mongodb_url, - rabbitmq_url=rabbitmq_url - ) - except Exception: - self.log.error('Error during startup of processing server. ' - 'Trying to kill parts of incompletely deployed service') - self.deployer.kill_all() - raise - uvicorn.run(self, host=self.hostname, port=int(self.port)) - - async def on_startup(self): - await initiate_database(db_url=self.mongodb_url) - - async def on_shutdown(self) -> None: - """ - - hosts and pids should be stored somewhere - - ensure queue is empty or processor is not currently running - - connect to hosts and kill pids - """ - await self.stop_deployed_agents() + self.include_router(workflow_router) async def home_page(self): message = f"The home page of the {self.title}" @@ -315,194 +310,90 @@ async def home_page(self): return json_message async def stop_deployed_agents(self) -> None: - self.deployer.kill_all() - - def connect_publisher(self, enable_acks: bool = True) -> None: - self.log.info(f'Connecting RMQPublisher to RabbitMQ server: ' - f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}') - self.rmq_publisher = RMQPublisher( - host=self.rmq_host, - port=self.rmq_port, - vhost=self.rmq_vhost - ) - self.log.debug(f'RMQPublisher authenticates with username: ' - f'{self.rmq_username}, password: {self.rmq_password}') - self.rmq_publisher.authenticate_and_connect( - username=self.rmq_username, - password=self.rmq_password - ) - if enable_acks: - self.rmq_publisher.enable_delivery_confirmations() - self.log.info('Delivery confirmations are enabled') - self.log.info('Successfully connected RMQPublisher.') - - def create_message_queues(self) -> None: - """ Create the message queues based on the occurrence of - `workers.name` in the config file. - """ + self.deployer.stop_all() + + def query_ocrd_tool_json_from_server(self, processor_name: str) -> Dict: + processor_server_base_url = self.deployer.resolve_processor_server_url(processor_name) + if processor_server_base_url == '': + message = f"Processor Server URL of '{processor_name}' not found" + raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message=message) + return request_processor_server_tool_json(self.log, processor_server_base_url=processor_server_base_url) + + async def get_network_agent_ocrd_tool( + self, processor_name: str, agent_type: AgentType = AgentType.PROCESSING_WORKER + ) -> Dict: + ocrd_tool = {} + error_message = f"Network agent of type '{agent_type}' for processor '{processor_name}' not found." 
+ if agent_type != AgentType.PROCESSING_WORKER and agent_type != AgentType.PROCESSOR_SERVER: + message = f"Unknown agent type: {agent_type}, {type(agent_type)}" + raise_http_exception(self.log, status_code=status.HTTP_501_NOT_IMPLEMENTED, message=message) + if agent_type == AgentType.PROCESSING_WORKER: + ocrd_tool = self.ocrd_all_tool_json.get(processor_name, None) + if agent_type == AgentType.PROCESSOR_SERVER: + ocrd_tool = self.query_ocrd_tool_json_from_server(processor_name) + if not ocrd_tool: + raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, error_message) + return ocrd_tool - # The abstract version of the above lines - queue_names = self.deployer.find_matching_processors( - worker_only=True, - str_names_only=True, - unique_only=True - ) + def network_agent_exists_server(self, processor_name: str) -> bool: + processor_server_url = self.deployer.resolve_processor_server_url(processor_name) + return bool(processor_server_url) + def network_agent_exists_worker(self, processor_name: str) -> bool: # TODO: Reconsider and refactor this. # Added ocrd-dummy by default if not available for the integration tests. - # A proper Processing Worker / Processor Server registration endpoint is needed on the Processing Server side - if 'ocrd-dummy' not in queue_names: - queue_names.append('ocrd-dummy') - - for queue_name in queue_names: - # The existence/validity of the worker.name is not tested. - # Even if an ocr-d processor does not exist, the queue is created - self.log.info(f'Creating a message queue with id: {queue_name}') - self.rmq_publisher.create_queue(queue_name=queue_name) - - @staticmethod - def create_processing_message(job: DBProcessorJob) -> OcrdProcessingMessage: - processing_message = OcrdProcessingMessage( - job_id=job.job_id, - processor_name=job.processor_name, - created_time=generate_created_time(), - path_to_mets=job.path_to_mets, - workspace_id=job.workspace_id, - input_file_grps=job.input_file_grps, - output_file_grps=job.output_file_grps, - page_id=job.page_id, - parameters=job.parameters, - result_queue_name=job.result_queue_name, - callback_url=job.callback_url, - internal_callback_url=job.internal_callback_url - ) - return processing_message - - def check_if_queue_exists(self, processor_name) -> bool: - try: - # Only checks if the process queue exists, if not raises ChannelClosedByBroker - self.rmq_publisher.create_queue(processor_name, passive=True) + # A proper Processing Worker / Processor Server registration endpoint + # is needed on the Processing Server side + if processor_name == 'ocrd-dummy': return True - except ChannelClosedByBroker as error: - self.log.warning(f"Process queue with id '{processor_name}' not existing: {error}") - # TODO: Revisit when reconnection strategy is implemented - # Reconnect publisher, i.e., restore the connection - not efficient, but works - self.connect_publisher(enable_acks=True) - return False - - def query_ocrd_tool_json_from_server(self, processor_server_url: str): - # Request the tool json from the Processor Server - response = requests.get( - urljoin(processor_server_url, 'info'), - headers={"Content-Type": "application/json"} - ) - if not response.status_code == 200: - msg = f"Failed to retrieve ocrd tool json from: {processor_server_url}, status code: {response.status_code}" - self.log.exception(msg) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=msg - ) - ocrd_tool = response.json() - return ocrd_tool - - def processing_agent_exists(self, processor_name: str, agent_type: str) -> bool: - if 
agent_type not in [NETWORK_AGENT_SERVER, NETWORK_AGENT_WORKER]: - return False - if agent_type == NETWORK_AGENT_WORKER: - # TODO: Reconsider and refactor this. - # Added ocrd-dummy by default if not available for the integration tests. - # A proper Processing Worker / Processor Server registration endpoint - # is needed on the Processing Server side - if processor_name == 'ocrd-dummy': - return True - if not self.check_if_queue_exists(processor_name): - return False - if agent_type == NETWORK_AGENT_SERVER: - processor_server_url = self.deployer.resolve_processor_server_url(processor_name) - if not processor_server_url: - return False - return True - - async def get_processing_agent_ocrd_tool(self, processor_name: str, agent_type: str) -> dict: - ocrd_tool = {} - if agent_type == NETWORK_AGENT_WORKER: - ocrd_tool = await self.get_processor_info(processor_name) - if agent_type == NETWORK_AGENT_SERVER: - processor_server_url = self.deployer.resolve_processor_server_url(processor_name) - ocrd_tool = self.query_ocrd_tool_json_from_server(processor_server_url) - return ocrd_tool + return bool(check_if_queue_exists(self.log, self.rmq_data, processor_name=processor_name)) + + def validate_agent_type_and_existence(self, processor_name: str, agent_type: AgentType) -> None: + agent_exists = False + if agent_type == AgentType.PROCESSOR_SERVER: + agent_exists = self.network_agent_exists_server(processor_name=processor_name) + elif agent_type == AgentType.PROCESSING_WORKER: + agent_exists = self.network_agent_exists_worker(processor_name=processor_name) + else: + message = f"Unknown agent type: {agent_type}, {type(agent_type)}" + raise_http_exception(self.log, status_code=status.HTTP_501_NOT_IMPLEMENTED, message=message) + if not agent_exists: + message = f"Network agent of type '{agent_type}' for processor '{processor_name}' not found." 
+ raise_http_exception(self.log, status.HTTP_422_UNPROCESSABLE_ENTITY, message) - async def push_processor_job(self, processor_name: str, data: PYJobInput) -> PYJobOutput: + async def validate_and_forward_job_to_network_agent(self, processor_name: str, data: PYJobInput) -> PYJobOutput: + # Append the processor name to the request itself + data.processor_name = processor_name + self.validate_agent_type_and_existence(processor_name=data.processor_name, agent_type=data.agent_type) if data.job_id: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Job id field is set but must not be: {data.job_id}" - ) - if not data.workspace_id and not data.path_to_mets: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="either 'path_to_mets' or 'workspace_id' must be provided" - ) + message = f"Processing request job id field is set but must not be: {data.job_id}" + raise_http_exception(self.log, status.HTTP_422_UNPROCESSABLE_ENTITY, message) # Generate processing job id data.job_id = generate_id() - # Append the processor name to the request itself - data.processor_name = processor_name - - # Check if the processing agent (worker/server) exists (is deployed) - if not self.processing_agent_exists(data.processor_name, data.agent_type): - msg = f"Agent of type '{data.agent_type}' does not exist for '{data.processor_name}'" - self.log.exception(msg) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=msg - ) - - ocrd_tool = await self.get_processing_agent_ocrd_tool( + ocrd_tool = await self.get_network_agent_ocrd_tool( processor_name=data.processor_name, agent_type=data.agent_type ) - if not ocrd_tool: - msg = f"Agent of type '{data.agent_type}' does not exist for '{data.processor_name}'" - self.log.exception(msg) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=msg - ) - validate_job_input(self.log, data.processor_name, ocrd_tool, data) if data.workspace_id: - try: - db_workspace = await db_get_workspace(workspace_id=data.workspace_id) - except ValueError as error: - self.log.exception(error) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workspace with id `{data.workspace_id}` not found in the DB." - ) + # just a check whether the workspace exists in the database or not + await get_from_database_workspace(self.log, data.workspace_id) else: # data.path_to_mets provided instead - try: - # TODO: Reconsider and refactor this. Core cannot create workspaces by api, but processing-server needs - # the workspace in the database. Here the workspace is created if the path is available locally and - # not existing in the DB - since it has not been uploaded through the Workspace Server. 
- await db_create_workspace(data.path_to_mets) - except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f"Mets file not existing: {data.path_to_mets}") + await create_workspace_if_not_exists(self.log, mets_path=data.path_to_mets) workspace_key = data.path_to_mets if data.path_to_mets else data.workspace_id # initialize the request counter for the workspace_key self.cache_processing_requests.update_request_counter(workspace_key=workspace_key, by_value=0) - # Since the path is not resolved yet, - # the return value is not important for the Processing Server + # This check is done to return early in case a workspace_id is provided + # but the abs mets path cannot be queried from the DB request_mets_path = await validate_and_return_mets_path(self.log, data) page_ids = expand_page_ids(data.page_id) # A flag whether the current request must be cached - # This is set to true if for any output fileGrp there + # This is set to true if for any output file group there # is a page_id value that has been previously locked cache_current_request = False @@ -528,7 +419,7 @@ async def push_processor_job(self, processor_name: str, data: PYJobInput) -> PYJ db_cached_job = DBProcessorJob( **data.dict(exclude_unset=True, exclude_none=True), internal_callback_url=self.internal_job_callback_url, - state=StateEnum.cached + state=JobState.cached ) await db_cached_job.insert() return db_cached_job.to_job_output() @@ -554,84 +445,49 @@ async def push_processor_job(self, processor_name: str, data: PYJobInput) -> PYJ db_queued_job = DBProcessorJob( **data.dict(exclude_unset=True, exclude_none=True), internal_callback_url=self.internal_job_callback_url, - state=StateEnum.queued + state=JobState.queued ) await db_queued_job.insert() self.cache_processing_requests.update_request_counter(workspace_key=workspace_key, by_value=1) - job_output = await self.push_to_processing_agent(data=data, db_job=db_queued_job) + job_output = await self.push_job_to_network_agent(data=data, db_job=db_queued_job) return job_output - async def push_to_processing_agent(self, data: PYJobInput, db_job: DBProcessorJob) -> PYJobOutput: - if data.agent_type == NETWORK_AGENT_WORKER: - processing_message = self.create_processing_message(db_job) - self.log.debug(f"Pushing to processing worker: {data.processor_name}, {data.page_id}, {data.job_id}") - await self.push_to_processing_queue(data.processor_name, processing_message) - job_output = db_job.to_job_output() - else: # data.agent_type == NETWORK_AGENT_SERVER - self.log.debug(f"Pushing to processor server: {data.processor_name}, {data.page_id}, {data.job_id}") - job_output = await self.push_to_processor_server(data.processor_name, data) + async def push_job_to_network_agent(self, data: PYJobInput, db_job: DBProcessorJob) -> PYJobOutput: + if data.agent_type != AgentType.PROCESSING_WORKER and data.agent_type != AgentType.PROCESSOR_SERVER: + message = f"Unknown agent type: {data.agent_type}, {type(data.agent_type)}" + raise_http_exception(self.log, status_code=status.HTTP_501_NOT_IMPLEMENTED, message=message) + job_output = None + self.log.debug(f"Pushing to {data.agent_type}: {data.processor_name}, {data.page_id}, {data.job_id}") + if data.agent_type == AgentType.PROCESSING_WORKER: + job_output = await self.push_job_to_processing_queue(db_job=db_job) + if data.agent_type == AgentType.PROCESSOR_SERVER: + job_output = await self.push_job_to_processor_server(job_input=data) if not job_output: - self.log.exception('Failed to create job output') - raise 
HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail='Failed to create job output' - ) + message = f"Failed to create job output for job input: {data}" + raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message) return job_output - async def push_to_processing_queue(self, processor_name: str, processing_message: OcrdProcessingMessage): + async def push_job_to_processing_queue(self, db_job: DBProcessorJob) -> PYJobOutput: if not self.rmq_publisher: - raise Exception('RMQPublisher is not connected') + message = "The Processing Server has no connection to RabbitMQ Server. RMQPublisher is not connected." + raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message) + processing_message = create_processing_message(self.log, db_job) try: - self.rmq_publisher.publish_to_queue( - queue_name=processor_name, - message=OcrdProcessingMessage.encode_yml(processing_message) - ) + encoded_message = OcrdProcessingMessage.encode_yml(processing_message) + self.rmq_publisher.publish_to_queue(queue_name=db_job.processor_name, message=encoded_message) except Exception as error: - self.log.exception(f'RMQPublisher has failed: {error}') - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f'RMQPublisher has failed: {error}' - ) - - async def push_to_processor_server( - self, - processor_name: str, - job_input: PYJobInput - ) -> PYJobOutput: - try: - json_data = json.dumps(job_input.dict(exclude_unset=True, exclude_none=True)) - except Exception as e: - msg = f"Failed to json dump the PYJobInput, error: {e}" - self.log.exception(msg) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=msg - ) - - processor_server_url = self.deployer.resolve_processor_server_url(processor_name) - - # TODO: The amount of pages should come as a request input - # TODO: cf https://github.com/OCR-D/core/pull/1030/files#r1152551161 - # currently, use 200 as a default - amount_of_pages = 200 - request_timeout = 20.0 * amount_of_pages # 20 sec timeout per page - # Post a processing job to the Processor Server asynchronously - timeout = httpx.Timeout(timeout=request_timeout, connect=30.0) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post( - urljoin(processor_server_url, 'run'), - headers={'Content-Type': 'application/json'}, - json=json.loads(json_data) + message = ( + f"Processing server has failed to push processing message to queue: {db_job.processor_name}, " + f"Processing message: {processing_message.__dict__}" ) + raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message, error) + return db_job.to_job_output() - if not response.status_code == 202: - self.log.exception(f"Failed to post '{processor_name}' job to: {processor_server_url}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to post '{processor_name}' job to: {processor_server_url}" - ) - job_output = response.json() - return job_output + async def push_job_to_processor_server(self, job_input: PYJobInput) -> PYJobOutput: + processor_server_base_url = self.deployer.resolve_processor_server_url(job_input.processor_name) + return await forward_job_to_processor_server( + self.log, job_input=job_input, processor_server_base_url=processor_server_base_url + ) async def get_processor_job(self, job_id: str) -> PYJobOutput: return await _get_processor_job(self.log, job_id) @@ -639,57 +495,74 @@ async def get_processor_job(self, job_id: str) -> 
PYJobOutput: async def get_processor_job_log(self, job_id: str) -> FileResponse: return await _get_processor_job_log(self.log, job_id) - async def remove_from_request_cache(self, result_message: PYResultMessage): - result_job_id = result_message.job_id - result_job_state = result_message.state - path_to_mets = result_message.path_to_mets - workspace_id = result_message.workspace_id - self.log.debug(f"Result job_id: {result_job_id}, state: {result_job_state}") + async def _lock_pages_of_workspace( + self, workspace_key: str, output_file_grps: List[str], page_ids: List[str] + ) -> None: + # Lock the output file group pages for the current request + self.cache_locked_pages.lock_pages( + workspace_key=workspace_key, + output_file_grps=output_file_grps, + page_ids=page_ids + ) - # Read DB workspace entry - db_workspace = await db_get_workspace(workspace_id=workspace_id, workspace_mets_path=path_to_mets) - if not db_workspace: - self.log.exception(f"Workspace with id: {workspace_id} or path: {path_to_mets} not found in DB") - mets_server_url = db_workspace.mets_server_url - workspace_key = path_to_mets if path_to_mets else workspace_id + async def _unlock_pages_of_workspace( + self, workspace_key: str, output_file_grps: List[str], page_ids: List[str] + ) -> None: + self.cache_locked_pages.unlock_pages( + workspace_key=workspace_key, + output_file_grps=output_file_grps, + page_ids=page_ids + ) - if result_job_state == StateEnum.failed: - await self.cache_processing_requests.cancel_dependent_jobs( + async def push_cached_jobs_to_agents(self, processing_jobs: List[PYJobInput]) -> None: + if not len(processing_jobs): + self.log.debug("No processing jobs were consumed from the requests cache") + return + for data in processing_jobs: + self.log.info(f"Changing the job status of: {data.job_id} from {JobState.cached} to {JobState.queued}") + db_consumed_job = await db_update_processing_job(job_id=data.job_id, state=JobState.queued) + workspace_key = data.path_to_mets if data.path_to_mets else data.workspace_id + + # Lock the output file group pages for the current request + await self._lock_pages_of_workspace( workspace_key=workspace_key, - processing_job_id=result_job_id + output_file_grps=data.output_file_grps, + page_ids=expand_page_ids(data.page_id) ) - if result_job_state != StateEnum.success: - # TODO: Handle other potential error cases - pass - - db_result_job = await db_get_processing_job(result_job_id) - if not db_result_job: - self.log.exception(f"Processing job with id: {result_job_id} not found in DB") + self.cache_processing_requests.update_request_counter(workspace_key=workspace_key, by_value=1) + job_output = await self.push_job_to_network_agent(data=data, db_job=db_consumed_job) + if not job_output: + self.log.exception(f"Failed to create job output for job input data: {data}") - # Unlock the output file group pages for the result processing request - self.cache_locked_pages.unlock_pages( + async def _cancel_cached_dependent_jobs(self, workspace_key: str, job_id: str) -> None: + await self.cache_processing_requests.cancel_dependent_jobs( workspace_key=workspace_key, - output_file_grps=db_result_job.output_file_grps, - page_ids=expand_page_ids(db_result_job.page_id) + processing_job_id=job_id ) - # Take the next request from the cache (if any available) + async def _consume_cached_jobs_of_workspace( + self, workspace_key: str, mets_server_url: str + ) -> List[PYJobInput]: + + # Check whether the internal queue for the workspace key still exists if workspace_key not in 
self.cache_processing_requests.processing_requests: self.log.debug(f"No internal queue available for workspace with key: {workspace_key}") - return + return [] - # decrease the internal counter by 1 - request_counter = self.cache_processing_requests.update_request_counter(workspace_key=workspace_key, by_value=-1) - self.log.debug(f"Internal processing counter value: {request_counter}") + # decrease the internal cache counter by 1 + request_counter = self.cache_processing_requests.update_request_counter( + workspace_key=workspace_key, by_value=-1 + ) + self.log.debug(f"Internal processing job cache counter value: {request_counter}") if not len(self.cache_processing_requests.processing_requests[workspace_key]): if request_counter <= 0: # Shut down the Mets Server for the workspace_key since no # more internal callbacks are expected for that workspace self.log.debug(f"Stopping the mets server: {mets_server_url}") self.deployer.stop_unix_mets_server(mets_server_url=mets_server_url) - # The queue is empty - delete it try: + # The queue is empty - delete it del self.cache_processing_requests.processing_requests[workspace_key] except KeyError: self.log.warning(f"Trying to delete non-existing internal queue with key: {workspace_key}") @@ -697,75 +570,72 @@ async def remove_from_request_cache(self, result_message: PYResultMessage): # For debugging purposes it is good to see if any locked pages are left self.log.debug(f"Contents of the locked pages cache for: {workspace_key}") locked_pages = self.cache_locked_pages.get_locked_pages(workspace_key=workspace_key) - for output_fileGrp in locked_pages: - self.log.debug(f"{output_fileGrp}: {locked_pages[output_fileGrp]}") + for output_file_grp in locked_pages: + self.log.debug(f"{output_file_grp}: {locked_pages[output_file_grp]}") else: self.log.debug(f"Internal request cache is empty but waiting for {request_counter} result callbacks.") - return - + return [] consumed_requests = await self.cache_processing_requests.consume_cached_requests(workspace_key=workspace_key) + return consumed_requests - if not len(consumed_requests): - self.log.debug("No processing jobs were consumed from the requests cache") - return + async def remove_job_from_request_cache(self, result_message: PYResultMessage): + result_job_id = result_message.job_id + result_job_state = result_message.state + path_to_mets = result_message.path_to_mets + workspace_id = result_message.workspace_id + self.log.info(f"Result job_id: {result_job_id}, state: {result_job_state}") - for data in consumed_requests: - self.log.debug(f"Changing the job status of: {data.job_id} from {StateEnum.cached} to {StateEnum.queued}") - db_consumed_job = await db_update_processing_job(job_id=data.job_id, state=StateEnum.queued) - workspace_key = data.path_to_mets if data.path_to_mets else data.workspace_id + db_workspace = await get_from_database_workspace(self.log, workspace_id, path_to_mets) + mets_server_url = db_workspace.mets_server_url + workspace_key = path_to_mets if path_to_mets else workspace_id - # Lock the output file group pages for the current request - self.cache_locked_pages.lock_pages( - workspace_key=workspace_key, - output_file_grps=data.output_file_grps, - page_ids=expand_page_ids(data.page_id) - ) - self.cache_processing_requests.update_request_counter(workspace_key=workspace_key, by_value=1) - job_output = await self.push_to_processing_agent(data=data, db_job=db_consumed_job) - if not job_output: - self.log.exception(f'Failed to create job output for job input data: {data}') + if 
result_job_state == JobState.failed: + await self._cancel_cached_dependent_jobs(workspace_key, result_job_id) - async def get_processor_info(self, processor_name) -> Dict: - """ Return a processor's ocrd-tool.json - """ - ocrd_tool = self.ocrd_all_tool_json.get(processor_name, None) - if not ocrd_tool: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Ocrd tool JSON of '{processor_name}' not available!" + if result_job_state != JobState.success: + # TODO: Handle other potential error cases + pass + + try: + db_result_job = await db_get_processing_job(result_job_id) + # Unlock the output file group pages for the result processing request + await self._unlock_pages_of_workspace( + workspace_key=workspace_key, + output_file_grps=db_result_job.output_file_grps, + page_ids=expand_page_ids(db_result_job.page_id) ) + except ValueError as error: + message = f"Processing result job with id '{result_job_id}' not found in the DB." + raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message, error) - # TODO: Returns the ocrd tool json even of processors - # that are not deployed. This may or may not be desired. - return ocrd_tool + consumed_cached_jobs = await self._consume_cached_jobs_of_workspace( + workspace_key=workspace_key, mets_server_url=mets_server_url + ) + await self.push_cached_jobs_to_agents(processing_jobs=consumed_cached_jobs) async def list_processors(self) -> List[str]: # There is no caching on the Processing Server side - processor_names_list = self.deployer.find_matching_processors( - docker_only=False, - native_only=False, - worker_only=False, - server_only=False, - str_names_only=True, - unique_only=True + processor_names_list = self.deployer.find_matching_network_agents( + docker_only=False, native_only=False, worker_only=False, server_only=False, + str_names_only=True, unique_only=True ) return processor_names_list async def task_sequence_to_processing_jobs( - self, - tasks: List[ProcessorTask], - mets_path: str, - page_id: str, - agent_type: NETWORK_AGENT_WORKER, + self, + tasks: List[ProcessorTask], + mets_path: str, + page_id: str, + agent_type: AgentType = AgentType.PROCESSING_WORKER ) -> List[PYJobOutput]: - file_group_cache = {} + temp_file_group_cache = {} responses = [] for task in tasks: # Find dependent jobs of the current task dependent_jobs = [] for input_file_grp in task.input_file_grps: - if input_file_grp in file_group_cache: - dependent_jobs.append(file_group_cache[input_file_grp]) + if input_file_grp in temp_file_group_cache: + dependent_jobs.append(temp_file_group_cache[input_file_grp]) # NOTE: The `task.mets_path` and `task.page_id` is not utilized in low level # Thus, setting these two flags in the ocrd process workflow file has no effect job_input_data = PYJobInput( @@ -778,97 +648,64 @@ async def task_sequence_to_processing_jobs( agent_type=agent_type, depends_on=dependent_jobs, ) - response = await self.push_processor_job( + response = await self.validate_and_forward_job_to_network_agent( processor_name=job_input_data.processor_name, data=job_input_data ) for file_group in task.output_file_grps: - file_group_cache[file_group] = response.job_id + temp_file_group_cache[file_group] = response.job_id responses.append(response) return responses - async def run_workflow( - self, - mets_path: str, - workflow: Union[UploadFile, None] = File(None), - workflow_id: str = None, - agent_type: str = NETWORK_AGENT_WORKER, - page_id: str = None, - page_wise: bool = False, - workflow_callback_url: str = None - ) -> PYWorkflowJobOutput: - 
try: - # core cannot create workspaces by api, but processing-server needs the workspace in the - # database. Here the workspace is created if the path available and not existing in db: - await db_create_workspace(mets_path) - except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f"Mets file not existing: {mets_path}") - - if not workflow: - if not workflow_id: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Either workflow or workflow_id must be provided") + def validate_tasks_agents_existence(self, tasks: List[ProcessorTask], agent_type: AgentType) -> None: + missing_agents = [] + for task in tasks: try: - workflow = await db_get_workflow_script(workflow_id) - except ValueError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workflow with id '{workflow_id}' not found") - workflow = workflow.content - else: - workflow = (await workflow.read()).decode("utf-8") + self.validate_agent_type_and_existence(processor_name=task.executable, agent_type=agent_type) + except HTTPException: + # catching the error is not relevant here + missing_agents.append({task.executable, agent_type}) + if missing_agents: + message = ( + "Workflow validation has failed. The desired network agents not found. " + f"Missing processing agents: {missing_agents}" + ) + raise_http_exception(self.log, status.HTTP_406_NOT_ACCEPTABLE, message) - try: - tasks_list = workflow.splitlines() - tasks = [ProcessorTask.parse(task_str) for task_str in tasks_list if task_str.strip()] - except ValueError as e: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Error parsing tasks: {e}") + async def run_workflow( + self, + mets_path: str, + workflow: Union[UploadFile, None] = File(None), + workflow_id: str = None, + agent_type: AgentType = AgentType.PROCESSING_WORKER, + page_id: str = None, + page_wise: bool = False, + workflow_callback_url: str = None + ) -> PYWorkflowJobOutput: + await create_workspace_if_not_exists(self.log, mets_path=mets_path) + workflow_content = await get_workflow_content(self.log, workflow_id, workflow) + processing_tasks = parse_workflow_tasks(self.log, workflow_content) # Validate the input file groups of the first task in the workflow - available_groups = Workspace(Resolver(), Path(mets_path).parents[0]).mets.file_groups - for grp in tasks[0].input_file_grps: - if grp not in available_groups: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Input file grps of 1st processor not found: {tasks[0].input_file_grps}" - ) + validate_first_task_input_file_groups_existence(self.log, mets_path, processing_tasks[0].input_file_grps) # Validate existence of agents (processing workers/processor servers) # for the ocr-d processors referenced inside tasks - missing_agents = [] - for task in tasks: - if not self.processing_agent_exists(processor_name=task.executable, agent_type=agent_type): - missing_agents.append({task.executable, agent_type}) - if missing_agents: - raise HTTPException( - status_code=status.HTTP_406_NOT_ACCEPTABLE, - detail=f"Workflow validation has failed. Processing agents not found: {missing_agents}. 
" - f"Make sure the desired processors are deployed either as a processing " - f"worker or processor server" - ) + self.validate_tasks_agents_existence(processing_tasks, agent_type) - try: - if page_id: - page_range = expand_page_ids(page_id) - else: - # If no page_id is specified, all physical pages are assigned as page range - page_range = get_ocrd_workspace_physical_pages(mets_path=mets_path) - compact_page_range = f'{page_range[0]}..{page_range[-1]}' - except BaseException as e: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Error determining page-range: {e}") + page_ids = get_page_ids_list(self.log, mets_path, page_id) + + # TODO: Reconsider this, the compact page range may not always work if the page_ids are hashes! + compact_page_range = f"{page_ids[0]}..{page_ids[-1]}" if not page_wise: responses = await self.task_sequence_to_processing_jobs( - tasks=tasks, + tasks=processing_tasks, mets_path=mets_path, page_id=compact_page_range, agent_type=agent_type ) - processing_job_ids = [] - for response in responses: - processing_job_ids.append(response.job_id) + processing_job_ids = [response.job_id for response in responses] db_workflow_job = DBWorkflowJob( job_id=generate_id(), page_id=compact_page_range, @@ -881,16 +718,14 @@ async def run_workflow( return db_workflow_job.to_job_output() all_pages_job_ids = {} - for current_page in page_range: + for current_page in page_ids: responses = await self.task_sequence_to_processing_jobs( - tasks=tasks, + tasks=processing_tasks, mets_path=mets_path, page_id=current_page, agent_type=agent_type ) - processing_job_ids = [] - for response in responses: - processing_job_ids.append(response.job_id) + processing_job_ids = [response.job_id for response in responses] all_pages_job_ids[current_page] = processing_job_ids db_workflow_job = DBWorkflowJob( job_id=generate_id(), @@ -903,114 +738,102 @@ async def run_workflow( await db_workflow_job.insert() return db_workflow_job.to_job_output() - async def get_workflow_info(self, workflow_job_id) -> Dict: - """ Return list of a workflow's processor jobs - """ - try: - workflow_job = await db_get_workflow_job(workflow_job_id) - except ValueError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workflow-Job with id: {workflow_job_id} not found") - job_ids: List[str] = [id for lst in workflow_job.processing_job_ids.values() for id in lst] - jobs = await db_get_processing_jobs(job_ids) - res = {} + @staticmethod + def _produce_workflow_status_response(processing_jobs: List[DBProcessorJob]) -> Dict: + response = {} failed_tasks = {} failed_tasks_key = "failed-processor-tasks" - for job in jobs: - res.setdefault(job.processor_name, {}) - res[job.processor_name].setdefault(job.state.value, 0) - res[job.processor_name][job.state.value] += 1 - if job.state == "FAILED": - if failed_tasks_key not in res: - res[failed_tasks_key] = failed_tasks - failed_tasks.setdefault(job.processor_name, []) - failed_tasks[job.processor_name].append({ - "job_id": job.job_id, - "page_id": job.page_id, - }) - return res + for p_job in processing_jobs: + response.setdefault(p_job.processor_name, {}) + response[p_job.processor_name].setdefault(p_job.state.value, 0) + response[p_job.processor_name][p_job.state.value] += 1 + if p_job.state == JobState.failed: + if failed_tasks_key not in response: + response[failed_tasks_key] = failed_tasks + failed_tasks.setdefault(p_job.processor_name, []) + failed_tasks[p_job.processor_name].append( + {"job_id": p_job.job_id, "page_id": 
p_job.page_id} + ) + return response - """ - Simplified version of the `get_workflow_info` that returns a single state for the entire workflow. - - If a single processing job fails, the entire workflow job status is set to FAILED. - - If there are any processing jobs running, regardless of other states, such as QUEUED and CACHED, - the entire workflow job status is set to RUNNING. - - If all processing jobs has finished successfully, only then the workflow job status is set to SUCCESS - """ - async def get_workflow_info_simple(self, workflow_job_id) -> Dict[str, StateEnum]: + @staticmethod + def _produce_workflow_status_simple_response(processing_jobs: List[DBProcessorJob]) -> JobState: + workflow_job_state = JobState.unset + success_jobs = 0 + for p_job in processing_jobs: + if p_job.state == JobState.cached or p_job.state == JobState.queued: + continue + if p_job.state == JobState.failed or p_job.state == JobState.cancelled: + workflow_job_state = JobState.failed + break + if p_job.state == JobState.running: + workflow_job_state = JobState.running + if p_job.state == JobState.success: + success_jobs += 1 + if len(processing_jobs) == success_jobs: + workflow_job_state = JobState.success + return workflow_job_state + + async def get_workflow_info(self, workflow_job_id) -> Dict: """ Return list of a workflow's processor jobs """ - try: - workflow_job = await db_get_workflow_job(workflow_job_id) - except ValueError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workflow-Job with id: {workflow_job_id} not found") + workflow_job = await get_from_database_workflow_job(self.log, workflow_job_id) job_ids: List[str] = [job_id for lst in workflow_job.processing_job_ids.values() for job_id in lst] jobs = await db_get_processing_jobs(job_ids) + response = self._produce_workflow_status_response(processing_jobs=jobs) + return response - workflow_job_state = StateEnum.unset - success_jobs = 0 - for job in jobs: - if job.state == StateEnum.cached or job.state == StateEnum.queued: - continue - if job.state == StateEnum.failed or job.state == StateEnum.cancelled: - workflow_job_state = StateEnum.failed - break - if job.state == StateEnum.running: - workflow_job_state = StateEnum.running - if job.state == StateEnum.success: - success_jobs += 1 - # if all jobs succeeded - if len(job_ids) == success_jobs: - workflow_job_state = StateEnum.success + async def get_workflow_info_simple(self, workflow_job_id) -> Dict[str, JobState]: + """ + Simplified version of the `get_workflow_info` that returns a single state for the entire workflow. + - If a single processing job fails, the entire workflow job status is set to FAILED. + - If there are any processing jobs running, regardless of other states, such as QUEUED and CACHED, + the entire workflow job status is set to RUNNING. 
+ - If all processing jobs has finished successfully, only then the workflow job status is set to SUCCESS + """ + workflow_job = await get_from_database_workflow_job(self.log, workflow_job_id) + job_ids: List[str] = [job_id for lst in workflow_job.processing_job_ids.values() for job_id in lst] + jobs = await db_get_processing_jobs(job_ids) + workflow_job_state = self._produce_workflow_status_simple_response(processing_jobs=jobs) return {"state": workflow_job_state} - async def upload_workflow(self, workflow: UploadFile) -> Dict: + async def upload_workflow(self, workflow: UploadFile) -> Dict[str, str]: """ Store a script for a workflow in the database """ - workflow_id = generate_id() - content = (await workflow.read()).decode("utf-8") - if not validate_workflow(content): - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Provided workflow script is invalid") - - content_hash = md5(content.encode("utf-8")).hexdigest() + workflow_content = await generate_workflow_content(workflow) + validate_workflow(self.log, workflow_content) + content_hash = generate_workflow_content_hash(workflow_content) try: db_workflow_script = await db_find_first_workflow_script_by_content(content_hash) if db_workflow_script: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="The same workflow" - f"-script exists with id '{db_workflow_script.workflow_id}'") + message = f"The same workflow script already exists, workflow id: {db_workflow_script.workflow_id}" + raise_http_exception(self.log, status.HTTP_409_CONFLICT, message) except ValueError: pass - + workflow_id = generate_id() db_workflow_script = DBWorkflowScript( workflow_id=workflow_id, - content=content, - content_hash=content_hash, + content=workflow_content, + content_hash=content_hash ) await db_workflow_script.insert() return {"workflow_id": workflow_id} - async def replace_workflow(self, workflow_id, workflow: UploadFile) -> str: + async def replace_workflow(self, workflow_id, workflow: UploadFile) -> Dict[str, str]: """ Update a workflow script file in the database """ - content = (await workflow.read()).decode("utf-8") - if not validate_workflow(content): - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Provided workflow script is invalid") try: db_workflow_script = await db_get_workflow_script(workflow_id) - db_workflow_script.content = content - content_hash = md5(content.encode("utf-8")).hexdigest() + workflow_content = await generate_workflow_content(workflow) + validate_workflow(self.log, workflow_content) + db_workflow_script.content = workflow_content + content_hash = generate_workflow_content_hash(workflow_content) db_workflow_script.content_hash = content_hash - except ValueError as e: - self.log.exception(f"Workflow with id '{workflow_id}' not existing, error: {e}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workflow-script with id '{workflow_id}' not existing" - ) - await db_workflow_script.save() - return db_workflow_script.workflow_id + await db_workflow_script.save() + return {"workflow_id": db_workflow_script.workflow_id} + except ValueError as error: + message = f"Workflow script not existing for id '{workflow_id}'." 
+ raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message, error) async def download_workflow(self, workflow_id) -> PlainTextResponse: """ Load workflow-script from the database @@ -1018,9 +841,6 @@ async def download_workflow(self, workflow_id) -> PlainTextResponse: try: workflow = await db_get_workflow_script(workflow_id) return PlainTextResponse(workflow.content) - except ValueError as e: - self.log.exception(f"Workflow with id '{workflow_id}' not existing, error: {e}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Workflow-script with id '{workflow_id}' not existing" - ) + except ValueError as error: + message = f"Workflow script not existing for id '{workflow_id}'." + raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message, error) diff --git a/src/ocrd_network/processing_worker.py b/src/ocrd_network/processing_worker.py index ff4287c84..a352ea5fd 100644 --- a/src/ocrd_network/processing_worker.py +++ b/src/ocrd_network/processing_worker.py @@ -9,60 +9,44 @@ """ from datetime import datetime -from logging import FileHandler, Formatter from os import getpid -from time import sleep -import pika.spec -import pika.adapters.blocking_connection -from pika.exceptions import AMQPConnectionError +from pika import BasicProperties +from pika.adapters.blocking_connection import BlockingChannel +from pika.spec import Basic -from ocrd_utils import config, getLogger, LOG_FORMAT -from .database import ( - sync_initiate_database, - sync_db_get_workspace, - sync_db_update_processing_job, -) -from .logging import ( +from ocrd_utils import getLogger +from .constants import JobState +from .database import sync_initiate_database, sync_db_get_workspace, sync_db_update_processing_job, verify_database_uri +from .logging_utils import ( + configure_file_handler_with_formatter, get_processing_job_logging_file_path, - get_processing_worker_logging_file_path + get_processing_worker_logging_file_path, ) -from .models import StateEnum from .process_helpers import invoke_processor from .rabbitmq_utils import ( + connect_rabbitmq_consumer, + connect_rabbitmq_publisher, OcrdProcessingMessage, OcrdResultMessage, - RMQConsumer, - RMQPublisher -) -from .utils import ( - calculate_execution_time, - post_to_callback_url, - verify_database_uri, verify_and_parse_mq_uri ) +from .utils import calculate_execution_time, post_to_callback_url class ProcessingWorker: def __init__(self, rabbitmq_addr, mongodb_addr, processor_name, ocrd_tool: dict, processor_class=None) -> None: self.log = getLogger(f'ocrd_network.processing_worker') log_file = get_processing_worker_logging_file_path(processor_name=processor_name, pid=getpid()) - file_handler = FileHandler(filename=log_file, mode='a') - file_handler.setFormatter(Formatter(LOG_FORMAT)) - self.log.addHandler(file_handler) + configure_file_handler_with_formatter(self.log, log_file=log_file, mode="a") try: verify_database_uri(mongodb_addr) self.log.debug(f'Verified MongoDB URL: {mongodb_addr}') - rmq_data = verify_and_parse_mq_uri(rabbitmq_addr) - self.rmq_username = rmq_data['username'] - self.rmq_password = rmq_data['password'] - self.rmq_host = rmq_data['host'] - self.rmq_port = rmq_data['port'] - self.rmq_vhost = rmq_data['vhost'] - self.log.debug(f'Verified RabbitMQ Credentials: {self.rmq_username}:{self.rmq_password}') - self.log.debug(f'Verified RabbitMQ Server URL: {self.rmq_host}:{self.rmq_port}{self.rmq_vhost}') - except ValueError as e: - raise ValueError(e) + self.rmq_data = verify_and_parse_mq_uri(rabbitmq_addr) + except 
ValueError as error: + msg = f"Failed to parse data, error: {error}" + self.log.exception(msg) + raise ValueError(msg) sync_initiate_database(mongodb_addr) # Database client self.ocrd_tool = ocrd_tool @@ -75,107 +59,95 @@ def __init__(self, rabbitmq_addr, mongodb_addr, processor_name, ocrd_tool: dict, # Used to consume OcrdProcessingMessage from the queue with name {processor_name} self.rmq_consumer = None # Gets assigned when the `connect_publisher` is called on the worker object - # The publisher is connected when the `result_queue` field of the OcrdProcessingMessage is set for first time # Used to publish OcrdResultMessage type message to the queue with name {processor_name}-result self.rmq_publisher = None + + def connect_consumer(self): + self.rmq_consumer = connect_rabbitmq_consumer(self.log, self.rmq_data) # Always create a queue (idempotent) - self.create_queue() + self.rmq_consumer.create_queue(queue_name=self.processor_name) - def connect_consumer(self) -> None: - self.log.info(f'Connecting RMQConsumer to RabbitMQ server: ' - f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}') - self.rmq_consumer = RMQConsumer( - host=self.rmq_host, - port=self.rmq_port, - vhost=self.rmq_vhost - ) - self.log.debug(f'RMQConsumer authenticates with username: ' - f'{self.rmq_username}, password: {self.rmq_password}') - self.rmq_consumer.authenticate_and_connect( - username=self.rmq_username, - password=self.rmq_password - ) - self.log.info(f'Successfully connected RMQConsumer.') - - def connect_publisher(self, enable_acks: bool = True) -> None: - self.log.info(f'Connecting RMQPublisher to RabbitMQ server: ' - f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}') - self.rmq_publisher = RMQPublisher( - host=self.rmq_host, - port=self.rmq_port, - vhost=self.rmq_vhost - ) - self.log.debug(f'RMQPublisher authenticates with username: ' - f'{self.rmq_username}, password: {self.rmq_password}') - self.rmq_publisher.authenticate_and_connect( - username=self.rmq_username, - password=self.rmq_password - ) - if enable_acks: - self.rmq_publisher.enable_delivery_confirmations() - self.log.info('Delivery confirmations are enabled') - self.log.info('Successfully connected RMQPublisher.') + def connect_publisher(self, enable_acks: bool = True): + self.rmq_publisher = connect_rabbitmq_publisher(self.log, self.rmq_data, enable_acks=enable_acks) # Define what happens every time a message is consumed # from the queue with name self.processor_name def on_consumed_message( - self, - channel: pika.adapters.blocking_connection.BlockingChannel, - delivery: pika.spec.Basic.Deliver, - properties: pika.spec.BasicProperties, - body: bytes) -> None: + self, + channel: BlockingChannel, + delivery: Basic.Deliver, + properties: BasicProperties, + body: bytes + ) -> None: consumer_tag = delivery.consumer_tag delivery_tag: int = delivery.delivery_tag is_redelivered: bool = delivery.redelivered message_headers: dict = properties.headers - self.log.debug(f'Consumer tag: {consumer_tag}, ' - f'message delivery tag: {delivery_tag}, ' - f'redelivered: {is_redelivered}') - self.log.debug(f'Message headers: {message_headers}') + ack_message = f"Acking message with tag: {delivery_tag}" + nack_message = f"Nacking processing message with tag: {delivery_tag}" + + self.log.debug( + f"Consumer tag: {consumer_tag}" + f", message delivery tag: {delivery_tag}" + f", redelivered: {is_redelivered}" + ) + self.log.debug(f"Message headers: {message_headers}") try: - self.log.debug(f'Trying to decode processing message with tag: {delivery_tag}') + 
self.log.debug(f"Trying to decode processing message with tag: {delivery_tag}") processing_message: OcrdProcessingMessage = OcrdProcessingMessage.decode_yml(body) - except Exception as e: - self.log.error(f'Failed to decode processing message body: {body}') - self.log.error(f'Nacking processing message with tag: {delivery_tag}') + except Exception as error: + msg = f"Failed to decode processing message with tag: {delivery_tag}, error: {error}" + self.log.exception(msg) + self.log.info(nack_message) channel.basic_nack(delivery_tag=delivery_tag, multiple=False, requeue=False) - raise Exception(f'Failed to decode processing message with tag: {delivery_tag}, reason: {e}') + raise Exception(msg) try: - self.log.info(f'Starting to process the received message: {processing_message.__dict__}') + self.log.info(f"Starting to process the received message: {processing_message.__dict__}") self.process_message(processing_message=processing_message) - except Exception as e: - self.log.error(f'Failed to process processing message with tag: {delivery_tag}') - self.log.error(f'Nacking processing message with tag: {delivery_tag}') + except Exception as error: + message = ( + f"Failed to process message with tag: {delivery_tag}. " + f"Processing message: {processing_message.__dict__}" + ) + self.log.exception(f"{message}, error: {error}") + self.log.info(nack_message) channel.basic_nack(delivery_tag=delivery_tag, multiple=False, requeue=False) - raise Exception(f'Failed to process processing message with tag: {delivery_tag}, reason: {e}') + raise Exception(message) - self.log.info(f'Successfully processed RabbitMQ message') - self.log.debug(f'Acking message with tag: {delivery_tag}') + self.log.info(f"Successfully processed RabbitMQ message") + self.log.debug(ack_message) channel.basic_ack(delivery_tag=delivery_tag, multiple=False) def start_consuming(self) -> None: if self.rmq_consumer: - self.log.info(f'Configuring consuming from queue: {self.processor_name}') + self.log.info(f"Configuring consuming from queue: {self.processor_name}") self.rmq_consumer.configure_consuming( queue_name=self.processor_name, callback_method=self.on_consumed_message ) - self.log.info(f'Starting consuming from queue: {self.processor_name}') + self.log.info(f"Starting consuming from queue: {self.processor_name}") # Starting consuming is a blocking action self.rmq_consumer.start_consuming() else: - raise Exception('The RMQConsumer is not connected/configured properly') + msg = f"The RMQConsumer is not connected/configured properly." + self.log.exception(msg) + raise Exception(msg) # TODO: Better error handling required to catch exceptions def process_message(self, processing_message: OcrdProcessingMessage) -> None: # Verify that the processor name in the processing message # matches the processor name of the current processing worker if self.processor_name != processing_message.processor_name: - raise ValueError(f'Processor name is not matching. Expected: {self.processor_name},' - f'Got: {processing_message.processor_name}') + message = ( + "Processor name is not matching. " + f"Expected: {self.processor_name}, " + f"Got: {processing_message.processor_name}" + ) + self.log.exception(message) + raise ValueError(message) # All of this is needed because the OcrdProcessingMessage object # may not contain certain keys. 
Simply passing None in the OcrdProcessingMessage constructor @@ -183,31 +155,29 @@ def process_message(self, processing_message: OcrdProcessingMessage) -> None: pm_keys = processing_message.__dict__.keys() job_id = processing_message.job_id input_file_grps = processing_message.input_file_grps - output_file_grps = processing_message.output_file_grps if 'output_file_grps' in pm_keys else None - path_to_mets = processing_message.path_to_mets if 'path_to_mets' in pm_keys else None - workspace_id = processing_message.workspace_id if 'workspace_id' in pm_keys else None - page_id = processing_message.page_id if 'page_id' in pm_keys else None - result_queue_name = processing_message.result_queue_name if 'result_queue_name' in pm_keys else None - callback_url = processing_message.callback_url if 'callback_url' in pm_keys else None - internal_callback_url = processing_message.internal_callback_url if 'internal_callback_url' in pm_keys else None + output_file_grps = processing_message.output_file_grps if "output_file_grps" in pm_keys else None + path_to_mets = processing_message.path_to_mets if "path_to_mets" in pm_keys else None + workspace_id = processing_message.workspace_id if "workspace_id" in pm_keys else None + page_id = processing_message.page_id if "page_id" in pm_keys else None parameters = processing_message.parameters if processing_message.parameters else {} if not path_to_mets and not workspace_id: - raise ValueError(f'`path_to_mets` nor `workspace_id` was set in the ocrd processing message') + msg = f"Both 'path_to_mets' and 'workspace_id' are missing in the OcrdProcessingMessage." + self.log.exception(msg) + raise ValueError(msg) - if path_to_mets: - mets_server_url = sync_db_get_workspace(workspace_mets_path=path_to_mets).mets_server_url + mets_server_url = sync_db_get_workspace(workspace_mets_path=path_to_mets).mets_server_url if not path_to_mets and workspace_id: path_to_mets = sync_db_get_workspace(workspace_id).workspace_mets_path mets_server_url = sync_db_get_workspace(workspace_id).mets_server_url execution_failed = False - self.log.debug(f'Invoking processor: {self.processor_name}') + self.log.debug(f"Invoking processor: {self.processor_name}") start_time = datetime.now() job_log_file = get_processing_job_logging_file_path(job_id=job_id) sync_db_update_processing_job( job_id=job_id, - state=StateEnum.running, + state=JobState.running, path_to_mets=path_to_mets, start_time=start_time, log_file_path=job_log_file @@ -225,73 +195,62 @@ def process_message(self, processing_message: OcrdProcessingMessage) -> None: mets_server_url=mets_server_url ) except Exception as error: - self.log.debug(f"processor_name: {self.processor_name}, path_to_mets: {path_to_mets}, " - f"input_grps: {input_file_grps}, output_file_grps: {output_file_grps}, " - f"page_id: {page_id}, parameters: {parameters}") - self.log.exception(error) + message = ( + f"processor_name: {self.processor_name}, " + f"path_to_mets: {path_to_mets}, " + f"input_file_grps: {input_file_grps}, " + f"output_file_grps: {output_file_grps}, " + f"page_id: {page_id}, " + f"parameters: {parameters}" + ) + self.log.exception(f"{message}, error: {error}") execution_failed = True end_time = datetime.now() exec_duration = calculate_execution_time(start_time, end_time) - job_state = StateEnum.success if not execution_failed else StateEnum.failed + job_state = JobState.success if not execution_failed else JobState.failed sync_db_update_processing_job( job_id=job_id, state=job_state, end_time=end_time, - exec_time=f'{exec_duration} ms' + 
exec_time=f"{exec_duration} ms" ) result_message = OcrdResultMessage( job_id=job_id, state=job_state.value, path_to_mets=path_to_mets, # May not be always available - workspace_id=workspace_id + workspace_id=workspace_id if workspace_id else '' ) - self.log.info(f'Result message: {result_message.__dict__}') + self.publish_result_to_all(processing_message=processing_message, result_message=result_message) + + def publish_result_to_all(self, processing_message: OcrdProcessingMessage, result_message: OcrdResultMessage): + pm_keys = processing_message.__dict__.keys() + result_queue_name = processing_message.result_queue_name if "result_queue_name" in pm_keys else None + callback_url = processing_message.callback_url if "callback_url" in pm_keys else None + internal_callback_url = processing_message.internal_callback_url if "internal_callback_url" in pm_keys else None + + self.log.info(f"Result message: {result_message.__dict__}") # If the result_queue field is set, send the result message to a result queue if result_queue_name: + self.log.info(f"Publishing result to message queue: {result_queue_name}") self.publish_to_result_queue(result_queue_name, result_message) if callback_url: + self.log.info(f"Publishing result to user defined callback url: {callback_url}") # If the callback_url field is set, # post the result message (callback to a user defined endpoint) post_to_callback_url(self.log, callback_url, result_message) if internal_callback_url: + self.log.info(f"Publishing result to internal callback url (Processing Server): {callback_url}") # If the internal callback_url field is set, # post the result message (callback to Processing Server endpoint) post_to_callback_url(self.log, internal_callback_url, result_message) def publish_to_result_queue(self, result_queue: str, result_message: OcrdResultMessage): - if self.rmq_publisher is None: + if not self.rmq_publisher: self.connect_publisher() # create_queue method is idempotent - nothing happens if # a queue with the specified name already exists self.rmq_publisher.create_queue(queue_name=result_queue) self.log.info(f'Publishing result message to queue: {result_queue}') encoded_result_message = OcrdResultMessage.encode_yml(result_message) - self.rmq_publisher.publish_to_queue( - queue_name=result_queue, - message=encoded_result_message - ) - - def create_queue( - self, - connection_attempts: int = config.OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS, - retry_delay: int = 3) -> None: - """Create the queue for this worker - - Originally only the processing-server created the queues for the workers according to the - configuration file. This is intended to make external deployment of workers possible. 
- """ - if self.rmq_publisher is None: - attempts_left = connection_attempts if connection_attempts > 0 else 1 - while attempts_left > 0: - try: - self.connect_publisher() - break - except AMQPConnectionError as e: - if attempts_left <= 1: - raise e - attempts_left -= 1 - sleep(retry_delay) - - # the following function is idempotent - self.rmq_publisher.create_queue(queue_name=self.processor_name) + self.rmq_publisher.publish_to_queue(queue_name=result_queue, message=encoded_result_message) diff --git a/src/ocrd_network/processor_server.py b/src/ocrd_network/processor_server.py index cacba8461..5aed89d72 100644 --- a/src/ocrd_network/processor_server.py +++ b/src/ocrd_network/processor_server.py @@ -1,19 +1,18 @@ from datetime import datetime -from logging import FileHandler, Formatter from os import getpid -from subprocess import run, PIPE -import uvicorn +from subprocess import run as subprocess_run, PIPE +from uvicorn import run -from fastapi import FastAPI, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, FastAPI, status from fastapi.responses import FileResponse from ocrd_utils import ( initLogging, get_ocrd_tool_json, getLogger, - LOG_FORMAT, parse_json_string_with_comments ) +from .constants import JobState, ServerApiTags from .database import ( DBProcessorJob, db_get_workspace, @@ -21,47 +20,38 @@ db_get_processing_job, initiate_database ) -from .logging import ( +from .logging_utils import ( + configure_file_handler_with_formatter, get_processor_server_logging_file_path, - get_processing_job_logging_file_path, -) -from .models import ( - PYJobInput, - PYJobOutput, - PYOcrdTool, - StateEnum + get_processing_job_logging_file_path ) +from .models import PYJobInput, PYJobOutput, PYOcrdTool from .process_helpers import invoke_processor from .rabbitmq_utils import OcrdResultMessage from .server_utils import ( _get_processor_job, _get_processor_job_log, + raise_http_exception, validate_and_return_mets_path, validate_job_input ) -from .utils import ( - calculate_execution_time, - post_to_callback_url, - generate_id, -) +from .utils import calculate_execution_time, post_to_callback_url, generate_id class ProcessorServer(FastAPI): def __init__(self, mongodb_addr: str, processor_name: str = "", processor_class=None): if not (processor_name or processor_class): - raise ValueError('Either "processor_name" or "processor_class" must be provided') + raise ValueError("Either 'processor_name' or 'processor_class' must be provided") initLogging() super().__init__( on_startup=[self.on_startup], on_shutdown=[self.on_shutdown], - title=f'OCR-D Processor Server', - description='OCR-D Processor Server' + title=f"Network agent - Processor Server", + description="Network agent - Processor Server" ) - self.log = getLogger('ocrd_network.processor_server') + self.log = getLogger("ocrd_network.processor_server") log_file = get_processor_server_logging_file_path(processor_name=processor_name, pid=getpid()) - file_handler = FileHandler(filename=log_file, mode='a') - file_handler.setFormatter(Formatter(LOG_FORMAT)) - self.log.addHandler(file_handler) + configure_file_handler_with_formatter(self.log, log_file=log_file, mode="a") self.db_url = mongodb_addr self.processor_name = processor_name @@ -76,74 +66,72 @@ def __init__(self, mongodb_addr: str, processor_name: str = "", processor_class= raise Exception(f"The ocrd_tool is empty or missing") if not self.processor_name: - self.processor_name = self.ocrd_tool['executable'] + self.processor_name = self.ocrd_tool["executable"] - # Create routes 
- self.router.add_api_route( - path='/info', + self.add_api_routes_processing() + + async def on_startup(self): + await initiate_database(db_url=self.db_url) + + async def on_shutdown(self) -> None: + """ + TODO: Perform graceful shutdown operations here + """ + pass + + def add_api_routes_processing(self): + processing_router = APIRouter() + processing_router.add_api_route( + path="/info", endpoint=self.get_processor_info, - methods=['GET'], - tags=['Processing'], + methods=["GET"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Get information about this processor.', + summary="Get information about this processor.", response_model=PYOcrdTool, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/run', + processing_router.add_api_route( + path="/run", endpoint=self.create_processor_task, - methods=['POST'], - tags=['Processing'], + methods=["POST"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_202_ACCEPTED, - summary='Submit a job to this processor.', + summary="Submit a job to this processor.", response_model=PYJobOutput, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/job/{job_id}', + processing_router.add_api_route( + path="/job/{job_id}", endpoint=self.get_processor_job, - methods=['GET'], - tags=['Processing'], + methods=["GET"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Get information about a job based on its ID', + summary="Get information about a job based on its ID", response_model=PYJobOutput, response_model_exclude_unset=True, response_model_exclude_none=True ) - - self.router.add_api_route( - path='/log/{job_id}', + processing_router.add_api_route( + path="/log/{job_id}", endpoint=self.get_processor_job_log, - methods=['GET'], - tags=['processing'], + methods=["GET"], + tags=[ServerApiTags.PROCESSING], status_code=status.HTTP_200_OK, - summary='Get the log file of a job id' + summary="Get the log file of a job id" ) - async def on_startup(self): - await initiate_database(db_url=self.db_url) - - async def on_shutdown(self) -> None: - """ - TODO: Perform graceful shutdown operations here - """ - pass - async def get_processor_info(self): if not self.ocrd_tool: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f'Empty or missing ocrd_tool' - ) + message = "Empty or missing ocrd tool json." 
+ raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message) return self.ocrd_tool # Note: The Processing server pushes to a queue, while # the Processor Server creates (pushes to) a background task - async def create_processor_task(self, job_input: PYJobInput): + async def create_processor_task(self, job_input: PYJobInput, background_tasks: BackgroundTasks): validate_job_input(self.log, self.processor_name, self.ocrd_tool, job_input) job_input.path_to_mets = await validate_and_return_mets_path(self.log, job_input) @@ -155,12 +143,13 @@ async def create_processor_task(self, job_input: PYJobInput): **job_input.dict(exclude_unset=True, exclude_none=True), job_id=job_id, processor_name=self.processor_name, - state=StateEnum.queued + state=JobState.queued ) await job.insert() else: job = await db_get_processing_job(job_input.job_id) - await self.run_processor_task(job=job) + # await self.run_processor_task(job=job) + background_tasks.add_task(self.run_processor_task, job) return job.to_job_output() async def run_processor_task(self, job: DBProcessorJob): @@ -169,7 +158,7 @@ async def run_processor_task(self, job: DBProcessorJob): job_log_file = get_processing_job_logging_file_path(job_id=job.job_id) await db_update_processing_job( job_id=job.job_id, - state=StateEnum.running, + state=JobState.running, start_time=start_time, log_file_path=job_log_file ) @@ -195,21 +184,21 @@ async def run_processor_task(self, job: DBProcessorJob): execution_failed = True end_time = datetime.now() exec_duration = calculate_execution_time(start_time, end_time) - job_state = StateEnum.success if not execution_failed else StateEnum.failed + job_state = JobState.success if not execution_failed else JobState.failed await db_update_processing_job( job_id=job.job_id, state=job_state, end_time=end_time, - exec_time=f'{exec_duration} ms' + exec_time=f"{exec_duration} ms" ) result_message = OcrdResultMessage( job_id=job.job_id, state=job_state.value, path_to_mets=job.path_to_mets, # May not be always available - workspace_id=job.workspace_id + workspace_id=job.workspace_id if job.workspace_id else '' ) - self.log.info(f'Result message: {result_message}') + self.log.info(f"Result message: {result_message}") if job.callback_url: # If the callback_url field is set, # post the result message (callback to a user defined endpoint) @@ -226,8 +215,8 @@ def get_ocrd_tool(self): # The way of accessing ocrd tool like in the line below may be problematic # ocrd_tool = self.processor_class(workspace=None, version=True).ocrd_tool ocrd_tool = parse_json_string_with_comments( - run( - [self.processor_name, '--dump-json'], + subprocess_run( + [self.processor_name, "--dump-json"], stdout=PIPE, check=True, universal_newlines=True @@ -247,8 +236,8 @@ def get_version(self) -> str: # version_str = self.processor_class(workspace=None, version=True).version return version_str """ - version_str = run( - [self.processor_name, '--version'], + version_str = subprocess_run( + [self.processor_name, "--version"], stdout=PIPE, check=True, universal_newlines=True @@ -256,10 +245,10 @@ def get_version(self) -> str: return version_str def run_server(self, host, port): - uvicorn.run(self, host=host, port=port) + run(self, host=host, port=port) async def get_processor_job(self, job_id: str) -> PYJobOutput: - return await _get_processor_job(self.log, self.processor_name, job_id) + return await _get_processor_job(self.log, job_id) async def get_processor_job_log(self, job_id: str) -> FileResponse: - return await 
_get_processor_job_log(self.log, self.processor_name, job_id) + return await _get_processor_job_log(self.log, job_id) diff --git a/src/ocrd_network/rabbitmq_utils/__init__.py b/src/ocrd_network/rabbitmq_utils/__init__.py index 2d5f55e62..93a8249ef 100644 --- a/src/ocrd_network/rabbitmq_utils/__init__.py +++ b/src/ocrd_network/rabbitmq_utils/__init__.py @@ -1,15 +1,26 @@ __all__ = [ - 'RMQConsumer', - 'RMQConnector', - 'RMQPublisher', - 'OcrdProcessingMessage', - 'OcrdResultMessage' + "check_if_queue_exists", + "connect_rabbitmq_consumer", + "connect_rabbitmq_publisher", + "create_message_queues", + "verify_and_parse_mq_uri", + "verify_rabbitmq_available", + "RMQConsumer", + "RMQConnector", + "RMQPublisher", + "OcrdProcessingMessage", + "OcrdResultMessage" ] from .consumer import RMQConsumer from .connector import RMQConnector -from .publisher import RMQPublisher -from .ocrd_messages import ( - OcrdProcessingMessage, - OcrdResultMessage +from .helpers import ( + check_if_queue_exists, + connect_rabbitmq_consumer, + connect_rabbitmq_publisher, + create_message_queues, + verify_and_parse_mq_uri, + verify_rabbitmq_available ) +from .publisher import RMQPublisher +from .ocrd_messages import OcrdProcessingMessage, OcrdResultMessage diff --git a/src/ocrd_network/rabbitmq_utils/connector.py b/src/ocrd_network/rabbitmq_utils/connector.py index 6dbc6ea0d..893d55a21 100644 --- a/src/ocrd_network/rabbitmq_utils/connector.py +++ b/src/ocrd_network/rabbitmq_utils/connector.py @@ -11,15 +11,15 @@ DEFAULT_EXCHANGER_TYPE, DEFAULT_QUEUE, DEFAULT_ROUTER, - RABBIT_MQ_HOST as HOST, - RABBIT_MQ_PORT as PORT, - RABBIT_MQ_VHOST as VHOST, + RABBIT_MQ_HOST, + RABBIT_MQ_PORT, + RABBIT_MQ_VHOST, PREFETCH_COUNT ) class RMQConnector: - def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> None: + def __init__(self, host: str = RABBIT_MQ_HOST, port: int = RABBIT_MQ_PORT, vhost: str = RABBIT_MQ_VHOST) -> None: self._host = host self._port = port self._vhost = vhost @@ -35,33 +35,33 @@ def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> No # keyboard interruption, i.e., CTRL + C self._gracefully_stopped = False + def close_connection(self, reply_code: int = 200, reply_text: str = "Normal shutdown"): + self._connection.close(reply_code=reply_code, reply_text=reply_text) + @staticmethod def declare_and_bind_defaults(connection: BlockingConnection, channel: BlockingChannel) -> None: if connection and connection.is_open: if channel and channel.is_open: # Declare the default exchange agent RMQConnector.exchange_declare( - channel=channel, - exchange_name=DEFAULT_EXCHANGER_NAME, - exchange_type=DEFAULT_EXCHANGER_TYPE, + channel=channel, exchange_name=DEFAULT_EXCHANGER_NAME, exchange_type=DEFAULT_EXCHANGER_TYPE, ) # Declare the default queue RMQConnector.queue_declare(channel, queue_name=DEFAULT_QUEUE) # Bind the default queue to the default exchange RMQConnector.queue_bind( - channel, - queue_name=DEFAULT_QUEUE, - exchange_name=DEFAULT_EXCHANGER_NAME, - routing_key=DEFAULT_ROUTER + channel=channel, queue_name=DEFAULT_QUEUE, + exchange_name=DEFAULT_EXCHANGER_NAME, routing_key=DEFAULT_ROUTER ) + return + raise ConnectionError("The channel is missing or closed.") + raise ConnectionError("The connection is missing or closed.") # Connection related methods @staticmethod def open_blocking_connection( - credentials: PlainCredentials, - host: str = HOST, - port: int = PORT, - vhost: str = VHOST + credentials: PlainCredentials, + host: str = RABBIT_MQ_HOST, port: int = RABBIT_MQ_PORT, 
vhost: str = RABBIT_MQ_VHOST ) -> BlockingConnection: blocking_connection = BlockingConnection( parameters=ConnectionParameters( @@ -80,15 +80,24 @@ def open_blocking_channel(connection: BlockingConnection) -> Union[BlockingChann if connection and connection.is_open: channel = connection.channel() return channel - return None + raise ConnectionError("The connection is missing or closed.") + + def _authenticate_and_connect(self, username: str, password: str) -> None: + # Delete credentials once connected + credentials = PlainCredentials(username=username, password=password, erase_on_connect=False) + self._connection = RMQConnector.open_blocking_connection( + host=self._host, port=self._port, vhost=self._vhost, credentials=credentials, + ) + self._channel = RMQConnector.open_blocking_channel(self._connection) + if not self._connection: + raise ConnectionError("The connection is missing or closed.") + if not self._channel: + raise ConnectionError("The channel is missing or closed.") @staticmethod def exchange_bind( - channel: BlockingChannel, - destination_exchange: str, - source_exchange: str, - routing_key: str, - arguments: Optional[Any] = None + channel: BlockingChannel, destination_exchange: str, source_exchange: str, routing_key: str, + arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} @@ -99,22 +108,18 @@ def exchange_bind( routing_key=routing_key, arguments=arguments ) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def exchange_declare( - channel: BlockingChannel, - exchange_name: str, - exchange_type: str, - passive: bool = False, - durable: bool = False, - auto_delete: bool = False, - internal: bool = False, - arguments: Optional[Any] = None + channel: BlockingChannel, exchange_name: str, exchange_type: str, passive: bool = False, durable: bool = False, + auto_delete: bool = False, internal: bool = False, arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} if channel and channel.is_open: - exchange = channel.exchange_declare( + channel.exchange_declare( exchange=exchange_name, exchange_type=exchange_type, # Only check to see if the exchange exists @@ -128,25 +133,21 @@ def exchange_declare( # Custom key/value pair arguments for the exchange arguments=arguments ) - return exchange + return + raise ConnectionError("The channel is missing or closed.") @staticmethod - def exchange_delete( - channel: BlockingChannel, - exchange_name: str, - if_unused: bool = False - ) -> None: + def exchange_delete(channel: BlockingChannel, exchange_name: str, if_unused: bool = False) -> None: # Deletes queue only if unused if channel and channel.is_open: channel.exchange_delete(exchange=exchange_name, if_unused=if_unused) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def exchange_unbind( - channel: BlockingChannel, - destination_exchange: str, - source_exchange: str, - routing_key: str, - arguments: Optional[Any] = None + channel: BlockingChannel, destination_exchange: str, source_exchange: str, routing_key: str, + arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} @@ -157,29 +158,24 @@ def exchange_unbind( routing_key=routing_key, arguments=arguments ) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def queue_bind( - channel: BlockingChannel, - queue_name: str, - exchange_name: str, - routing_key: str, - arguments: Optional[Any] = None + channel: BlockingChannel, queue_name: str, exchange_name: str, 
routing_key: str, arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} if channel and channel.is_open: channel.queue_bind(queue=queue_name, exchange=exchange_name, routing_key=routing_key, arguments=arguments) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def queue_declare( - channel: BlockingChannel, - queue_name: str, - passive: bool = False, - durable: bool = False, - exclusive: bool = False, - auto_delete: bool = False, - arguments: Optional[Any] = None + channel: BlockingChannel, queue_name: str, passive: bool = False, durable: bool = False, + exclusive: bool = False, auto_delete: bool = False, arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} @@ -199,13 +195,11 @@ def queue_declare( arguments=arguments ) return queue + raise ConnectionError("The channel is missing or closed.") @staticmethod def queue_delete( - channel: BlockingChannel, - queue_name: str, - if_unused: bool = False, - if_empty: bool = False + channel: BlockingChannel, queue_name: str, if_unused: bool = False, if_empty: bool = False ) -> None: if channel and channel.is_open: channel.queue_delete( @@ -215,6 +209,8 @@ def queue_delete( # Only delete if the queue is empty if_empty=if_empty ) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def queue_purge(channel: BlockingChannel, queue_name: str) -> None: @@ -223,11 +219,7 @@ def queue_purge(channel: BlockingChannel, queue_name: str) -> None: @staticmethod def queue_unbind( - channel: BlockingChannel, - queue_name: str, - exchange_name: str, - routing_key: str, - arguments: Optional[Any] = None + channel: BlockingChannel, queue_name: str, exchange_name: str, routing_key: str, arguments: Optional[Any] = None ) -> None: if arguments is None: arguments = {} @@ -238,13 +230,23 @@ def queue_unbind( routing_key=routing_key, arguments=arguments ) + return + raise ConnectionError("The channel is missing or closed.") + + def create_queue( + self, queue_name: str, exchange_name: Optional[str] = DEFAULT_EXCHANGER_NAME, + exchange_type: Optional[str] = "direct", passive: bool = False + ) -> None: + RMQConnector.exchange_declare(channel=self._channel, exchange_name=exchange_name, exchange_type=exchange_type) + RMQConnector.queue_declare(channel=self._channel, queue_name=queue_name, passive=passive) + # The queue name is used as a routing key, to keep implementation simple + RMQConnector.queue_bind( + channel=self._channel, queue_name=queue_name, exchange_name=exchange_name, routing_key=queue_name + ) @staticmethod def set_qos( - channel: BlockingChannel, - prefetch_size: int = 0, - prefetch_count: int = PREFETCH_COUNT, - global_qos: bool = False + channel: BlockingChannel, prefetch_size: int = 0, prefetch_count: int = PREFETCH_COUNT, global_qos: bool = False ) -> None: if channel and channel.is_open: channel.basic_qos( @@ -254,19 +256,19 @@ def set_qos( # Should the qos apply to all channels of the connection global_qos=global_qos ) + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def confirm_delivery(channel: BlockingChannel) -> None: if channel and channel.is_open: channel.confirm_delivery() + return + raise ConnectionError("The channel is missing or closed.") @staticmethod def basic_publish( - channel: BlockingChannel, - exchange_name: str, - routing_key: str, - message_body: bytes, - properties: BasicProperties + channel: BlockingChannel, exchange_name: str, routing_key: str, message_body: bytes, properties: 
BasicProperties ) -> None: if channel and channel.is_open: channel.basic_publish( @@ -275,3 +277,5 @@ def basic_publish( body=message_body, properties=properties ) + return + raise ConnectionError("The channel is missing or closed.") diff --git a/src/ocrd_network/rabbitmq_utils/constants.py b/src/ocrd_network/rabbitmq_utils/constants.py index 21596ef61..9cdcfec87 100644 --- a/src/ocrd_network/rabbitmq_utils/constants.py +++ b/src/ocrd_network/rabbitmq_utils/constants.py @@ -1,30 +1,35 @@ +from ocrd_utils import config + __all__ = [ - 'DEFAULT_EXCHANGER_NAME', - 'DEFAULT_EXCHANGER_TYPE', - 'DEFAULT_QUEUE', - 'DEFAULT_ROUTER', - 'RABBIT_MQ_HOST', - 'RABBIT_MQ_PORT', - 'RABBIT_MQ_VHOST', - 'RECONNECT_WAIT', - 'RECONNECT_TRIES', - 'PREFETCH_COUNT', + "DEFAULT_EXCHANGER_NAME", + "DEFAULT_EXCHANGER_TYPE", + "DEFAULT_QUEUE", + "DEFAULT_ROUTER", + "RABBIT_MQ_HOST", + "RABBIT_MQ_PORT", + "RABBIT_MQ_VHOST", + "RABBITMQ_URI_PATTERN", + "RECONNECT_WAIT", + "RECONNECT_TRIES", + "PREFETCH_COUNT", ] -DEFAULT_EXCHANGER_NAME: str = 'ocrd-network-default' -DEFAULT_EXCHANGER_TYPE: str = 'direct' -DEFAULT_QUEUE: str = 'ocrd-network-default' -DEFAULT_ROUTER: str = 'ocrd-network-default' +DEFAULT_EXCHANGER_NAME: str = "ocrd-network-default" +DEFAULT_EXCHANGER_TYPE: str = "direct" +DEFAULT_QUEUE: str = "ocrd-network-default" +DEFAULT_ROUTER: str = "ocrd-network-default" -# 'rabbit-mq-host' when Dockerized -RABBIT_MQ_HOST: str = 'localhost' +# "rabbit-mq-host" when Dockerized +RABBIT_MQ_HOST: str = "localhost" RABBIT_MQ_PORT: int = 5672 -RABBIT_MQ_VHOST: str = '/' +RABBIT_MQ_VHOST: str = "/" + +RABBITMQ_URI_PATTERN: str = r"^(?:([^:\/?#\s]+):\/{2})?(?:([^@\/?#\s]+)@)?([^\/?#\s]+)?(?:\/([^?#\s]*))?(?:[?]([^#\s]+))?\S*$" # Wait seconds before next reconnect try -RECONNECT_WAIT: int = 5 +RECONNECT_WAIT: int = 10 # Reconnect tries before timeout -RECONNECT_TRIES: int = 3 +RECONNECT_TRIES: int = config.OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS # QOS, i.e., how many messages to consume in a single go # Check here: https://www.rabbitmq.com/consumer-prefetch.html PREFETCH_COUNT: int = 1 diff --git a/src/ocrd_network/rabbitmq_utils/consumer.py b/src/ocrd_network/rabbitmq_utils/consumer.py index 0d8d905ea..96208fe3b 100644 --- a/src/ocrd_network/rabbitmq_utils/consumer.py +++ b/src/ocrd_network/rabbitmq_utils/consumer.py @@ -4,20 +4,14 @@ RabbitMQ documentation. 
""" from typing import Any, Union -from pika import PlainCredentials from ocrd_utils import getLogger -from .constants import ( - DEFAULT_QUEUE, - RABBIT_MQ_HOST as HOST, - RABBIT_MQ_PORT as PORT, - RABBIT_MQ_VHOST as VHOST -) from .connector import RMQConnector +from .constants import RABBIT_MQ_HOST, RABBIT_MQ_PORT, RABBIT_MQ_VHOST class RMQConsumer(RMQConnector): - def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> None: - self.log = getLogger('ocrd_network.rabbitmq_utils.consumer') + def __init__(self, host: str = RABBIT_MQ_HOST, port: int = RABBIT_MQ_PORT, vhost: str = RABBIT_MQ_VHOST) -> None: + self.log = getLogger("ocrd_network.rabbitmq_utils.consumer") super().__init__(host=host, port=port, vhost=vhost) self.consumer_tag = None self.consuming = False @@ -26,48 +20,23 @@ def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> No self.reconnect_delay = 0 def authenticate_and_connect(self, username: str, password: str) -> None: - credentials = PlainCredentials( - username=username, - password=password, - erase_on_connect=False # Delete credentials once connected - ) - self._connection = RMQConnector.open_blocking_connection( - host=self._host, - port=self._port, - vhost=self._vhost, - credentials=credentials, - ) - self._channel = RMQConnector.open_blocking_channel(self._connection) + super()._authenticate_and_connect(username=username, password=password) RMQConnector.set_qos(self._channel) self.log.info("Set QoS for the consumer") def setup_defaults(self) -> None: RMQConnector.declare_and_bind_defaults(self._connection, self._channel) - def get_one_message( - self, - queue_name: str, - auto_ack: bool = False - ) -> Union[Any, None]: + def get_one_message(self, queue_name: str, auto_ack: bool = False) -> Union[Any, None]: message = None if self._channel and self._channel.is_open: - message = self._channel.basic_get( - queue=queue_name, - auto_ack=auto_ack - ) + message = self._channel.basic_get(queue=queue_name, auto_ack=auto_ack) return message - def configure_consuming( - self, - queue_name: str, - callback_method: Any - ) -> None: - self.log.debug(f'Configuring consuming from queue: {queue_name}') + def configure_consuming(self, queue_name: str, callback_method: Any) -> None: + self.log.debug(f"Configuring consuming from queue: {queue_name}") self._channel.add_on_cancel_callback(self.__on_consumer_cancelled) - self.consumer_tag = self._channel.basic_consume( - queue_name, - callback_method - ) + self.consumer_tag = self._channel.basic_consume(queue_name, callback_method) self.was_consuming = True self.consuming = True @@ -81,10 +50,10 @@ def get_waiting_message_count(self) -> Union[int, None]: return None def __on_consumer_cancelled(self, frame: Any) -> None: - self.log.warning(f'The consumer was cancelled remotely in frame: {frame}') + self.log.warning(f"The consumer was cancelled remotely in frame: {frame}") if self._channel: self._channel.close() def ack_message(self, delivery_tag: int) -> None: - self.log.debug(f'Acknowledging message with delivery tag: {delivery_tag}') + self.log.debug(f"Acknowledging message with delivery tag: {delivery_tag}") self._channel.basic_ack(delivery_tag) diff --git a/src/ocrd_network/rabbitmq_utils/helpers.py b/src/ocrd_network/rabbitmq_utils/helpers.py new file mode 100644 index 000000000..122658d76 --- /dev/null +++ b/src/ocrd_network/rabbitmq_utils/helpers.py @@ -0,0 +1,106 @@ +from logging import Logger +from pika import URLParameters +from pika.exceptions import AMQPConnectionError, 
ChannelClosedByBroker +from re import match as re_match +from time import sleep +from typing import Dict, List, Union + +from .constants import RABBITMQ_URI_PATTERN, RECONNECT_TRIES, RECONNECT_WAIT +from .consumer import RMQConsumer +from .publisher import RMQPublisher + + +def __connect_rabbitmq_client( + logger: Logger, client_type: str, rmq_data: Dict, attempts: int = RECONNECT_TRIES, delay: int = RECONNECT_WAIT +) -> Union[RMQConsumer, RMQPublisher]: + try: + rmq_host: str = rmq_data["host"] + rmq_port: int = rmq_data["port"] + rmq_vhost: str = rmq_data["vhost"] + rmq_username: str = rmq_data["username"] + rmq_password: str = rmq_data["password"] + except ValueError as error: + raise Exception("Failed to parse RabbitMQ connection data") from error + logger.info(f"Connecting client to RabbitMQ server: {rmq_host}:{rmq_port}{rmq_vhost}") + logger.debug(f"RabbitMQ client authenticates with username: {rmq_username}, password: {rmq_password}") + while attempts > 0: + try: + if client_type == "consumer": + rmq_client = RMQConsumer(host=rmq_host, port=rmq_port, vhost=rmq_vhost) + elif client_type == "publisher": + rmq_client = RMQPublisher(host=rmq_host, port=rmq_port, vhost=rmq_vhost) + else: + raise RuntimeError(f"RabbitMQ client type can be either a consumer or publisher. Got: {client_type}") + rmq_client.authenticate_and_connect(username=rmq_username, password=rmq_password) + return rmq_client + except AMQPConnectionError: + attempts -= 1 + sleep(delay) + continue + raise RuntimeError(f"Failed to establish connection with the RabbitMQ Server. Connection data: {rmq_data}") + + +def connect_rabbitmq_consumer(logger: Logger, rmq_data: Dict) -> RMQConsumer: + rmq_consumer = __connect_rabbitmq_client(logger=logger, client_type="consumer", rmq_data=rmq_data) + logger.info(f"Successfully connected RMQConsumer") + return rmq_consumer + + +def connect_rabbitmq_publisher(logger: Logger, rmq_data: Dict, enable_acks: bool = True) -> RMQPublisher: + rmq_publisher = __connect_rabbitmq_client(logger=logger, client_type="publisher", rmq_data=rmq_data) + if enable_acks: + rmq_publisher.enable_delivery_confirmations() + logger.info("Delivery confirmations are enabled") + logger.info("Successfully connected RMQPublisher") + return rmq_publisher + + +def check_if_queue_exists(logger: Logger, rmq_data: Dict, processor_name: str) -> bool: + rmq_publisher = connect_rabbitmq_publisher(logger, rmq_data) + try: + # Passively checks whether the queue name exists, if not raises ChannelClosedByBroker + rmq_publisher.create_queue(processor_name, passive=True) + return True + except ChannelClosedByBroker as error: + # The created connection was forcibly closed by the RabbitMQ Server + logger.warning(f"Process queue with id '{processor_name}' not existing: {error}") + return False + + +def create_message_queues(logger: Logger, rmq_publisher: RMQPublisher, queue_names: List[str]) -> None: + # TODO: Reconsider and refactor this. + # Added ocrd-dummy by default if not available for the integration tests. + # A proper Processing Worker / Processor Server registration endpoint is needed on the Processing Server side + if "ocrd-dummy" not in queue_names: + queue_names.append("ocrd-dummy") + + for queue_name in queue_names: + # The existence/validity of the worker.name is not tested. 
+ # Even if an ocr-d processor does not exist, the queue is created + logger.info(f"Creating a message queue with id: {queue_name}") + rmq_publisher.create_queue(queue_name=queue_name) + + +def verify_and_parse_mq_uri(rabbitmq_address: str): + """ + Check the full list of available parameters in the docs here: + https://pika.readthedocs.io/en/stable/_modules/pika/connection.html#URLParameters + """ + match = re_match(pattern=RABBITMQ_URI_PATTERN, string=rabbitmq_address) + if not match: + raise ValueError(f"The message queue server address is in wrong format: '{rabbitmq_address}'") + url_params = URLParameters(rabbitmq_address) + parsed_data = { + "username": url_params.credentials.username, + "password": url_params.credentials.password, + "host": url_params.host, + "port": url_params.port, + "vhost": url_params.virtual_host + } + return parsed_data + + +def verify_rabbitmq_available(logger: Logger, rabbitmq_address: str) -> None: + rmq_data = verify_and_parse_mq_uri(rabbitmq_address=rabbitmq_address) + temp_publisher = connect_rabbitmq_publisher(logger, rmq_data, enable_acks=True) + temp_publisher.close_connection() diff --git a/src/ocrd_network/rabbitmq_utils/ocrd_messages.py b/src/ocrd_network/rabbitmq_utils/ocrd_messages.py index 4016cd71b..8b70e8bcd 100644 --- a/src/ocrd_network/rabbitmq_utils/ocrd_messages.py +++ b/src/ocrd_network/rabbitmq_utils/ocrd_messages.py @@ -1,36 +1,26 @@ from __future__ import annotations from typing import Any, Dict, List, Optional -import yaml - +from yaml import dump, safe_load from ocrd_validators import OcrdNetworkMessageValidator class OcrdProcessingMessage: def __init__( - self, - job_id: str, - processor_name: str, - created_time: int, - input_file_grps: List[str], - output_file_grps: Optional[List[str]], - path_to_mets: Optional[str], - workspace_id: Optional[str], - page_id: Optional[str], - result_queue_name: Optional[str], - callback_url: Optional[str], - internal_callback_url: Optional[str], - parameters: Dict[str, Any] = None + self, job_id: str, processor_name: str, created_time: int, input_file_grps: List[str], + output_file_grps: Optional[List[str]], path_to_mets: Optional[str], workspace_id: Optional[str], + page_id: Optional[str], result_queue_name: Optional[str], callback_url: Optional[str], + internal_callback_url: Optional[str], parameters: Dict[str, Any] = None ) -> None: if not job_id: - raise ValueError('job_id must be provided') + raise ValueError("job_id must be provided") if not processor_name: - raise ValueError('processor_name must be provided') + raise ValueError("processor_name must be provided") if not created_time: - raise ValueError('created time must be provided') + raise ValueError("created time must be provided") if not input_file_grps or len(input_file_grps) == 0: - raise ValueError('input_file_grps must be provided and contain at least 1 element') + raise ValueError("input_file_grps must be provided and contain at least 1 element") if not (workspace_id or path_to_mets): - raise ValueError('Either "workspace_id" or "path_to_mets" must be provided') + raise ValueError("Either 'workspace_id' or 'path_to_mets' must be provided") self.job_id = job_id self.processor_name = processor_name @@ -53,55 +43,53 @@ def __init__( self.parameters = parameters if parameters else {} @staticmethod - def encode_yml(ocrd_processing_message: OcrdProcessingMessage) -> bytes: - return yaml.dump(ocrd_processing_message.__dict__, indent=2).encode('utf-8') + def encode_yml(ocrd_processing_message: OcrdProcessingMessage, encode_type: str = 
"utf-8") -> bytes: + return dump(ocrd_processing_message.__dict__, indent=2).encode(encode_type) @staticmethod - def decode_yml(ocrd_processing_message: bytes) -> OcrdProcessingMessage: - msg = ocrd_processing_message.decode('utf-8') - data = yaml.safe_load(msg) + def decode_yml(ocrd_processing_message: bytes, decode_type: str = "utf-8") -> OcrdProcessingMessage: + msg = ocrd_processing_message.decode(decode_type) + data = safe_load(msg) report = OcrdNetworkMessageValidator.validate_message_processing(data) if not report.is_valid: - raise ValueError(f'Validating the processing message has failed:\n{report.errors}') + raise ValueError(f"Validating the processing message has failed:\n{report.errors}") return OcrdProcessingMessage( - job_id=data.get('job_id', None), - processor_name=data.get('processor_name', None), - created_time=data.get('created_time', None), - path_to_mets=data.get('path_to_mets', None), - workspace_id=data.get('workspace_id', None), - input_file_grps=data.get('input_file_grps', None), - output_file_grps=data.get('output_file_grps', None), - page_id=data.get('page_id', None), - parameters=data.get('parameters', None), - result_queue_name=data.get('result_queue_name', None), - callback_url=data.get('callback_url', None), - internal_callback_url=data.get('internal_callback_url', None) + job_id=data.get("job_id", None), + processor_name=data.get("processor_name", None), + created_time=data.get("created_time", None), + path_to_mets=data.get("path_to_mets", None), + workspace_id=data.get("workspace_id", None), + input_file_grps=data.get("input_file_grps", None), + output_file_grps=data.get("output_file_grps", None), + page_id=data.get("page_id", None), + parameters=data.get("parameters", None), + result_queue_name=data.get("result_queue_name", None), + callback_url=data.get("callback_url", None), + internal_callback_url=data.get("internal_callback_url", None) ) class OcrdResultMessage: - def __init__(self, job_id: str, state: str, - path_to_mets: Optional[str] = None, - workspace_id: Optional[str] = None) -> None: + def __init__(self, job_id: str, state: str, path_to_mets: Optional[str], workspace_id: Optional[str] = '') -> None: self.job_id = job_id self.state = state self.workspace_id = workspace_id self.path_to_mets = path_to_mets @staticmethod - def encode_yml(ocrd_result_message: OcrdResultMessage) -> bytes: - return yaml.dump(ocrd_result_message.__dict__, indent=2).encode('utf-8') + def encode_yml(ocrd_result_message: OcrdResultMessage, encode_type: str = "utf-8") -> bytes: + return dump(ocrd_result_message.__dict__, indent=2).encode(encode_type) @staticmethod - def decode_yml(ocrd_result_message: bytes) -> OcrdResultMessage: - msg = ocrd_result_message.decode('utf-8') - data = yaml.safe_load(msg) + def decode_yml(ocrd_result_message: bytes, decode_type: str = "utf-8") -> OcrdResultMessage: + msg = ocrd_result_message.decode(decode_type) + data = safe_load(msg) report = OcrdNetworkMessageValidator.validate_message_result(data) if not report.is_valid: - raise ValueError(f'Validating the result message has failed:\n{report.errors}') + raise ValueError(f"Validating the result message has failed:\n{report.errors}") return OcrdResultMessage( - job_id=data.get('job_id', None), - state=data.get('state', None), - path_to_mets=data.get('path_to_mets', None), - workspace_id=data.get('workspace_id', None), + job_id=data.get("job_id", None), + state=data.get("state", None), + path_to_mets=data.get("path_to_mets", None), + workspace_id=data.get("workspace_id", ''), ) diff --git 
a/src/ocrd_network/rabbitmq_utils/publisher.py b/src/ocrd_network/rabbitmq_utils/publisher.py index f77975a7e..a07a2629c 100644 --- a/src/ocrd_network/rabbitmq_utils/publisher.py +++ b/src/ocrd_network/rabbitmq_utils/publisher.py @@ -4,21 +4,15 @@ RabbitMQ documentation. """ from typing import Optional -from pika import BasicProperties, PlainCredentials +from pika import BasicProperties from ocrd_utils import getLogger -from .constants import ( - DEFAULT_EXCHANGER_NAME, - DEFAULT_ROUTER, - RABBIT_MQ_HOST as HOST, - RABBIT_MQ_PORT as PORT, - RABBIT_MQ_VHOST as VHOST -) from .connector import RMQConnector +from .constants import DEFAULT_EXCHANGER_NAME, RABBIT_MQ_HOST, RABBIT_MQ_PORT, RABBIT_MQ_VHOST class RMQPublisher(RMQConnector): - def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> None: - self.log = getLogger('ocrd_network.rabbitmq_utils.publisher') + def __init__(self, host: str = RABBIT_MQ_HOST, port: int = RABBIT_MQ_PORT, vhost: str = RABBIT_MQ_VHOST) -> None: + self.log = getLogger("ocrd_network.rabbitmq_utils.publisher") super().__init__(host=host, port=port, vhost=vhost) self.message_counter = 0 self.deliveries = {} @@ -27,66 +21,20 @@ def __init__(self, host: str = HOST, port: int = PORT, vhost: str = VHOST) -> No self.running = True def authenticate_and_connect(self, username: str, password: str) -> None: - credentials = PlainCredentials( - username=username, - password=password, - erase_on_connect=False # Delete credentials once connected - ) - self._connection = RMQConnector.open_blocking_connection( - host=self._host, - port=self._port, - vhost=self._vhost, - credentials=credentials, - ) - self._channel = RMQConnector.open_blocking_channel(self._connection) + super()._authenticate_and_connect(username=username, password=password) def setup_defaults(self) -> None: RMQConnector.declare_and_bind_defaults(self._connection, self._channel) - def create_queue( - self, - queue_name: str, - exchange_name: Optional[str] = None, - exchange_type: Optional[str] = None, - passive: bool = False - ) -> None: - if exchange_name is None: - exchange_name = DEFAULT_EXCHANGER_NAME - if exchange_type is None: - exchange_type = 'direct' - - RMQConnector.exchange_declare( - channel=self._channel, - exchange_name=exchange_name, - exchange_type=exchange_type - ) - RMQConnector.queue_declare( - channel=self._channel, - queue_name=queue_name, - passive=passive - ) - RMQConnector.queue_bind( - channel=self._channel, - queue_name=queue_name, - exchange_name=exchange_name, - # the routing key matches the queue name - routing_key=queue_name - ) - def publish_to_queue( - self, - queue_name: str, - message: bytes, - exchange_name: Optional[str] = None, - properties: Optional[BasicProperties] = None + self, queue_name: str, message: bytes, exchange_name: Optional[str] = DEFAULT_EXCHANGER_NAME, + properties: Optional[BasicProperties] = None ) -> None: - if exchange_name is None: - exchange_name = DEFAULT_EXCHANGER_NAME if properties is None: - headers = {'ocrd_network default header': 'ocrd_network default header value'} + headers = {"ocrd_network default header": "ocrd_network default header value"} properties = BasicProperties( - app_id='ocrd_network default app id', - content_type='application/json', + app_id="ocrd_network default app id", + content_type="application/json", headers=headers ) @@ -104,8 +52,8 @@ def publish_to_queue( self.message_counter += 1 self.deliveries[self.message_counter] = True - self.log.debug(f'Published message #{self.message_counter}') + 
self.log.debug(f"Published message #{self.message_counter} to queue: {queue_name}") def enable_delivery_confirmations(self) -> None: - self.log.debug('Enabling delivery confirmations (Confirm.Select RPC)') + self.log.debug("Enabling delivery confirmations (Confirm.Select RPC)") RMQConnector.confirm_delivery(channel=self._channel) diff --git a/src/ocrd_network/runtime_data.py b/src/ocrd_network/runtime_data.py deleted file mode 100644 index 59c658ada..000000000 --- a/src/ocrd_network/runtime_data.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations -from typing import Dict, List - -from .deployment_utils import ( - create_docker_client, - create_ssh_client, - DeployType -) - -__all__ = [ - 'DataHost', - 'DataMongoDB', - 'DataProcessingWorker', - 'DataProcessorServer', - 'DataRabbitMQ' -] - - -class DataHost: - def __init__(self, config: Dict) -> None: - self.address = config['address'] - self.username = config['username'] - self.password = config.get('password', None) - self.keypath = config.get('path_to_privkey', None) - - # These flags are used to track whether a connection - # of the specified type will be required - self.needs_ssh: bool = False - self.needs_docker: bool = False - - self.ssh_client = None - self.docker_client = None - - # TODO: Not sure this is DS is ideal, seems off - self.data_workers: List[DataProcessingWorker] = [] - self.data_servers: List[DataProcessorServer] = [] - - for worker in config.get('workers', []): - name = worker['name'] - count = worker['number_of_instance'] - deploy_type = DeployType.DOCKER if worker.get('deploy_type', None) == 'docker' else DeployType.NATIVE - if not self.needs_ssh and deploy_type == DeployType.NATIVE: - self.needs_ssh = True - if not self.needs_docker and deploy_type == DeployType.DOCKER: - self.needs_docker = True - for _ in range(count): - self.data_workers.append(DataProcessingWorker(self.address, deploy_type, name)) - - for server in config.get('servers', []): - name = server['name'] - port = server['port'] - deploy_type = DeployType.DOCKER if server.get('deploy_type', None) == 'docker' else DeployType.NATIVE - if not self.needs_ssh and deploy_type == DeployType.NATIVE: - self.needs_ssh = True - if not self.needs_docker and deploy_type == DeployType.DOCKER: - self.needs_docker = True - self.data_servers.append(DataProcessorServer(self.address, port, deploy_type, name)) - - # Key: processor_name, Value: list of ports - self.server_ports: dict = {} - - def create_client(self, client_type: str): - if client_type not in ['docker', 'ssh']: - raise ValueError(f'Host client type cannot be of type: {client_type}') - if client_type == 'ssh': - if not self.ssh_client: - self.ssh_client = create_ssh_client( - self.address, self.username, self.password, self.keypath) - return self.ssh_client - if client_type == 'docker': - if not self.docker_client: - self.docker_client = create_docker_client( - self.address, self.username, self.password, self.keypath - ) - return self.docker_client - - -class DataProcessingWorker: - def __init__(self, host: str, deploy_type: DeployType, processor_name: str) -> None: - self.host = host - self.deploy_type = deploy_type - self.processor_name = processor_name - # Assigned when deployed - self.pid = None - - -class DataProcessorServer: - def __init__(self, host: str, port: int, deploy_type: DeployType, processor_name: str) -> None: - self.host = host - self.port = port - self.deploy_type = deploy_type - self.processor_name = processor_name - # Assigned when deployed - self.pid = None - - -class 
DataMongoDB: - def __init__(self, config: Dict) -> None: - self.address = config['address'] - self.port = int(config['port']) - if 'ssh' in config: - self.ssh_username = config['ssh']['username'] - self.ssh_keypath = config['ssh'].get('path_to_privkey', None) - self.ssh_password = config['ssh'].get('password', None) - else: - self.ssh_username = None - self.ssh_keypath = None - self.ssh_password = None - - if 'credentials' in config: - self.username = config['credentials']['username'] - self.password = config['credentials']['password'] - self.url = f'mongodb://{self.username}:{self.password}@{self.address}:{self.port}' - else: - self.username = None - self.password = None - self.url = f'mongodb://{self.address}:{self.port}' - self.skip_deployment = config.get('skip_deployment', False) - # Assigned when deployed - self.pid = None - - -class DataRabbitMQ: - def __init__(self, config: Dict) -> None: - self.address = config['address'] - self.port = int(config['port']) - if 'ssh' in config: - self.ssh_username = config['ssh']['username'] - self.ssh_keypath = config['ssh'].get('path_to_privkey', None) - self.ssh_password = config['ssh'].get('password', None) - else: - self.ssh_username = None - self.ssh_keypath = None - self.ssh_password = None - - self.vhost = '/' - self.username = config['credentials']['username'] - self.password = config['credentials']['password'] - self.url = f'amqp://{self.username}:{self.password}@{self.address}:{self.port}{self.vhost}' - self.skip_deployment = config.get('skip_deployment', False) - # Assigned when deployed - self.pid = None diff --git a/src/ocrd_network/runtime_data/__init__.py b/src/ocrd_network/runtime_data/__init__.py new file mode 100644 index 000000000..e43be7ae3 --- /dev/null +++ b/src/ocrd_network/runtime_data/__init__.py @@ -0,0 +1,14 @@ +__all__ = [ + "Deployer", + "DataHost", + "DataMongoDB", + "DataNetworkAgent", + "DataRabbitMQ", + "DataProcessingWorker", + "DataProcessorServer" +] + +from .deployer import Deployer +from .hosts import DataHost +from .network_agents import DataNetworkAgent, DataProcessingWorker, DataProcessorServer +from .network_services import DataMongoDB, DataRabbitMQ diff --git a/src/ocrd_network/runtime_data/config_parser.py b/src/ocrd_network/runtime_data/config_parser.py new file mode 100644 index 000000000..23f574678 --- /dev/null +++ b/src/ocrd_network/runtime_data/config_parser.py @@ -0,0 +1,53 @@ +from typing import Dict, List +from yaml import safe_load + +from ocrd_validators import ProcessingServerConfigValidator +from .hosts import DataHost +from .network_services import DataMongoDB, DataRabbitMQ + + +def validate_and_load_config(config_path: str) -> Dict: + # Load and validate the config + with open(config_path) as fin: + ps_config = safe_load(fin) + report = ProcessingServerConfigValidator.validate(ps_config) + if not report.is_valid: + raise Exception(f"Processing-Server configuration file is invalid:\n{report.errors}") + return ps_config + + +# Parse MongoDB data from the Processing Server configuration file +def parse_mongodb_data(db_config: Dict) -> DataMongoDB: + db_ssh = db_config.get("ssh", {}) + db_credentials = db_config.get("credentials", {}) + return DataMongoDB( + host=db_config["address"], port=int(db_config["port"]), ssh_username=db_ssh.get("username", None), + ssh_keypath=db_ssh.get("path_to_privkey", None), ssh_password=db_ssh.get("password", None), + cred_username=db_credentials.get("username", None), cred_password=db_credentials.get("password", None), + 
skip_deployment=db_config.get("skip_deployment", False) + ) + + +# Parse RabbitMQ data from the Processing Server configuration file +def parse_rabbitmq_data(rmq_config: Dict) -> DataRabbitMQ: + rmq_ssh = rmq_config.get("ssh", {}) + rmq_credentials = rmq_config.get("credentials", {}) + return DataRabbitMQ( + host=rmq_config["address"], port=int(rmq_config["port"]), ssh_username=rmq_ssh.get("username", None), + ssh_keypath=rmq_ssh.get("path_to_privkey", None), ssh_password=rmq_ssh.get("password", None), + cred_username=rmq_credentials.get("username", None), cred_password=rmq_credentials.get("password", None), + skip_deployment=rmq_config.get("skip_deployment", False) + ) + + +def parse_hosts_data(hosts_config: Dict) -> List[DataHost]: + hosts_data: List[DataHost] = [] + for host_config in hosts_config: + hosts_data.append( + DataHost( + host=host_config["address"], username=host_config["username"], + password=host_config.get("password", None), keypath=host_config.get("path_to_privkey", None), + workers=host_config.get("workers", []), servers=host_config.get("servers", []) + ) + ) + return hosts_data diff --git a/src/ocrd_network/deployment_utils.py b/src/ocrd_network/runtime_data/connection_clients.py similarity index 54% rename from src/ocrd_network/deployment_utils.py rename to src/ocrd_network/runtime_data/connection_clients.py index 8c3ff7e46..67002a498 100644 --- a/src/ocrd_network/deployment_utils.py +++ b/src/ocrd_network/runtime_data/connection_clients.py @@ -1,35 +1,7 @@ from __future__ import annotations -from enum import Enum from docker import APIClient, DockerClient from docker.transport import SSHHTTPAdapter from paramiko import AutoAddPolicy, SSHClient -from time import sleep -import re - -from .rabbitmq_utils import RMQPublisher -from pymongo import MongoClient - -__all__ = [ - 'create_docker_client', - 'create_ssh_client', - 'DeployType', - 'verify_mongodb_available', - 'verify_rabbitmq_available' -] - - -def create_ssh_client(address: str, username: str, password: str = "", keypath: str = "") -> SSHClient: - client = SSHClient() - client.set_missing_host_key_policy(AutoAddPolicy) - try: - client.connect(hostname=address, username=username, password=password, key_filename=keypath) - except Exception as error: - raise Exception(f"Error creating SSHClient of host '{address}', reason: {error}") from error - return client - - -def create_docker_client(address: str, username: str, password: str = "", keypath: str = "") -> CustomDockerClient: - return CustomDockerClient(username, address, password=password, keypath=keypath) class CustomDockerClient(DockerClient): @@ -54,23 +26,29 @@ class CustomDockerClient(DockerClient): """ def __init__(self, user: str, host: str, **kwargs) -> None: - # the super-constructor is not called on purpose: it solely instantiates the APIClient. The - # missing `version` in that call would raise an error. APIClient is provided here as a - # replacement for what the super-constructor does + # The super-constructor is not called on purpose. It solely instantiates the APIClient. + # Missing 'version' in that call would raise an error. 
+ # The APIClient is provided here as a replacement for what the super-constructor does if not (user and host): - raise ValueError('Missing argument: user and host must both be provided') - if ('password' not in kwargs) != ('keypath' not in kwargs): - raise ValueError('Missing argument: one of password and keyfile is needed') - self.api = APIClient(f'ssh://{host}', use_ssh_client=True, version='1.41') - ssh_adapter = self.CustomSshHttpAdapter(f'ssh://{user}@{host}:22', **kwargs) - self.api.mount('http+docker://ssh', ssh_adapter) + raise ValueError("Missing 'user' and 'host' - both must be provided") + if ("password" in kwargs) and ("keypath" in kwargs): + if kwargs["password"] and kwargs["keypath"]: + raise ValueError("Both 'password' and 'keypath' provided - one must be provided") + if ("password" not in kwargs) and ("keypath" not in kwargs): + raise ValueError("Missing 'password' or 'keypath' - one must be provided") + self.api = APIClient(base_url=f"ssh://{host}", use_ssh_client=True, version="1.41") + self.api.mount( + prefix="http+docker://ssh", adapter=self.CustomSshHttpAdapter(base_url=f"ssh://{user}@{host}:22", **kwargs) + ) class CustomSshHttpAdapter(SSHHTTPAdapter): def __init__(self, base_url, password: str = "", keypath: str = "") -> None: self.password = password self.keypath = keypath - if bool(self.password) == bool(self.keypath): - raise Exception("Either 'password' or 'keypath' must be provided") + if not self.password and not self.keypath: + raise Exception("Missing 'password' or 'keypath' - one must be provided") + if self.password and self.keypath: + raise Exception("Both 'password' and 'keypath' provided - one must be provided") super().__init__(base_url) def _create_paramiko_client(self, base_url: str) -> None: @@ -80,44 +58,21 @@ def _create_paramiko_client(self, base_url: str) -> None: """ super()._create_paramiko_client(base_url) if self.password: - self.ssh_params['password'] = self.password + self.ssh_params["password"] = self.password elif self.keypath: - self.ssh_params['key_filename'] = self.keypath + self.ssh_params["key_filename"] = self.keypath self.ssh_client.set_missing_host_key_policy(AutoAddPolicy) -def verify_rabbitmq_available( - host: str, - port: int, - vhost: str, - username: str, - password: str -) -> None: - max_waiting_steps = 15 - while max_waiting_steps > 0: - try: - dummy_publisher = RMQPublisher(host=host, port=port, vhost=vhost) - dummy_publisher.authenticate_and_connect(username=username, password=password) - except Exception: - max_waiting_steps -= 1 - sleep(2) - else: - # TODO: Disconnect the dummy_publisher here before returning... - return - raise RuntimeError(f'Cannot connect to RabbitMQ host: {host}, port: {port}, ' - f'vhost: {vhost}, username: {username}') +def create_docker_client(address: str, username: str, password: str = "", keypath: str = "") -> CustomDockerClient: + return CustomDockerClient(username, address, password=password, keypath=keypath) -def verify_mongodb_available(mongo_url: str) -> None: +def create_ssh_client(address: str, username: str, password: str = "", keypath: str = "") -> SSHClient: + client = SSHClient() + client.set_missing_host_key_policy(AutoAddPolicy) try: - client = MongoClient(mongo_url, serverSelectionTimeoutMS=1000.0) - client.admin.command("ismaster") - except Exception: - raise RuntimeError(f'Cannot connect to MongoDB: {re.sub(r":[^@]+@", ":****@", mongo_url)}') - - -class DeployType(Enum): - """ Deploy-Type of the processing worker/processor server. 
- """ - DOCKER = 1 - NATIVE = 2 + client.connect(hostname=address, username=username, password=password, key_filename=keypath) + except Exception as error: + raise Exception(f"Error creating SSHClient of host '{address}', reason: {error}") from error + return client diff --git a/src/ocrd_network/runtime_data/deployer.py b/src/ocrd_network/runtime_data/deployer.py new file mode 100644 index 000000000..afc395b5c --- /dev/null +++ b/src/ocrd_network/runtime_data/deployer.py @@ -0,0 +1,174 @@ +""" +Abstraction of the deployment functionality for processors. + +The Processing Server provides the configuration parameters to the Deployer agent. +The Deployer agent runs the RabbitMQ Server, MongoDB and the Processing Hosts. +Each Processing Host may have several Processing Workers. +Each Processing Worker is an instance of an OCR-D processor. +""" +from __future__ import annotations +from pathlib import Path +from subprocess import Popen, run as subprocess_run +from time import sleep +from typing import Dict, List, Union + +from ocrd_utils import config, getLogger, safe_filename +from ..logging_utils import get_mets_server_logging_file_path +from ..utils import is_mets_server_running, stop_mets_server +from .config_parser import parse_hosts_data, parse_mongodb_data, parse_rabbitmq_data, validate_and_load_config +from .hosts import DataHost +from .network_services import DataMongoDB, DataRabbitMQ + + +class Deployer: + def __init__(self, config_path: str) -> None: + self.log = getLogger("ocrd_network.deployer") + ps_config = validate_and_load_config(config_path) + self.data_mongo: DataMongoDB = parse_mongodb_data(ps_config["database"]) + self.data_queue: DataRabbitMQ = parse_rabbitmq_data(ps_config["process_queue"]) + self.data_hosts: List[DataHost] = parse_hosts_data(ps_config["hosts"]) + self.internal_callback_url = ps_config.get("internal_callback_url", None) + self.mets_servers: Dict = {} # {"mets_server_url": "mets_server_pid"} + + # TODO: Reconsider this. + def find_matching_network_agents( + self, worker_only: bool = False, server_only: bool = False, docker_only: bool = False, + native_only: bool = False, str_names_only: bool = False, unique_only: bool = False + ) -> Union[List[str], List[object]]: + """Finds and returns a list of matching data objects of type: + `DataProcessingWorker` and `DataProcessorServer`. + + :py:attr:`worker_only` match only worker network agents (DataProcessingWorker) + :py:attr:`server_only` match only server network agents (DataProcessorServer) + :py:attr:`docker_only` match only docker network agents (DataProcessingWorker and DataProcessorServer) + :py:attr:`native_only` match only native network agents (DataProcessingWorker and DataProcessorServer) + :py:attr:`str_names_only` returns the processor_name filed instead of the Data* object + :py:attr:`unique_only` remove duplicate names from the matches + + `worker_only` and `server_only` are mutually exclusive to each other + `docker_only` and `native_only` are mutually exclusive to each other + `unique_only` is allowed only together with `str_names_only` + """ + + if worker_only and server_only: + msg = f"Only 'worker_only' or 'server_only' is allowed, not both." + self.log.exception(msg) + raise ValueError(msg) + if docker_only and native_only: + msg = f"Only 'docker_only' or 'native_only' is allowed, not both." 
+ self.log.exception(msg) + raise ValueError(msg) + if not str_names_only and unique_only: + msg = f"Value 'unique_only' is allowed only together with 'str_names_only'" + self.log.exception(msg) + raise ValueError(msg) + + # Find all matching objects of type DataProcessingWorker or DataProcessorServer + matched_objects = [] + for data_host in self.data_hosts: + if not server_only: + if not docker_only: + for data_worker in data_host.network_agents_worker_native: + matched_objects.append(data_worker) + if not native_only: + for data_worker in data_host.network_agents_worker_docker: + matched_objects.append(data_worker) + if not worker_only: + if not docker_only: + for data_server in data_host.network_agents_server_native: + matched_objects.append(data_server) + if not native_only: + for data_server in data_host.network_agents_server_docker: + matched_objects.append(data_server) + if not str_names_only: + return matched_objects + # Gets only the processor names of the matched objects + matched_names = [match.processor_name for match in matched_objects] + if not unique_only: + return matched_names + # Removes any duplicate entries from matched names + return list(dict.fromkeys(matched_names)) + + def resolve_processor_server_url(self, processor_name) -> str: + processor_server_url = '' + for data_host in self.data_hosts: + processor_server_url = data_host.resolve_processor_server_url(processor_name=processor_name) + return processor_server_url + + def deploy_network_agents(self, mongodb_url: str, rabbitmq_url: str) -> None: + self.log.debug("Deploying processing workers/processor servers...") + for host_data in self.data_hosts: + host_data.deploy_network_agents(logger=self.log, mongodb_url=mongodb_url, rabbitmq_url=rabbitmq_url) + + def stop_network_agents(self) -> None: + self.log.debug("Stopping processing workers/processor servers...") + for host_data in self.data_hosts: + host_data.stop_network_agents(logger=self.log) + + def deploy_rabbitmq(self) -> str: + self.data_queue.deploy_rabbitmq(self.log) + return self.data_queue.service_url + + def stop_rabbitmq(self): + self.data_queue.stop_service_rabbitmq(self.log) + + def deploy_mongodb(self) -> str: + self.data_mongo.deploy_mongodb(self.log) + return self.data_mongo.service_url + + def stop_mongodb(self): + self.data_mongo.stop_service_mongodb(self.log) + + def stop_all(self) -> None: + """ + The order of stopping is important to optimize graceful shutdown in the future. + If RabbitMQ server is stopped before stopping Processing Workers that may have + a bad outcome and leave Processing Workers in an unpredictable state. 
+ """ + self.stop_network_agents() + self.stop_mongodb() + self.stop_rabbitmq() + + def start_unix_mets_server(self, mets_path: str) -> Path: + log_file = get_mets_server_logging_file_path(mets_path=mets_path) + mets_server_url = Path(config.OCRD_NETWORK_SOCKETS_ROOT_DIR, f"{safe_filename(mets_path)}.sock") + if is_mets_server_running(mets_server_url=str(mets_server_url)): + self.log.warning(f"The mets server for {mets_path} is already started: {mets_server_url}") + return mets_server_url + cwd = Path(mets_path).parent + self.log.info(f"Starting UDS mets server: {mets_server_url}") + sub_process = Popen( + args=["nohup", "ocrd", "workspace", "--mets-server-url", f"{mets_server_url}", + "-d", f"{cwd}", "server", "start"], + shell=False, + stdout=open(file=log_file, mode="w"), + stderr=open(file=log_file, mode="a"), + cwd=cwd, + universal_newlines=True + ) + # Wait for the mets server to start + sleep(2) + self.mets_servers[mets_server_url] = sub_process.pid + return mets_server_url + + def stop_unix_mets_server(self, mets_server_url: str, stop_with_pid: bool = False) -> None: + self.log.info(f"Stopping UDS mets server: {mets_server_url}") + if stop_with_pid: + if Path(mets_server_url) not in self.mets_servers: + message = f"Mets server not found at URL: {mets_server_url}" + self.log.exception(message) + raise Exception(message) + mets_server_pid = self.mets_servers[Path(mets_server_url)] + subprocess_run( + args=["kill", "-s", "SIGINT", f"{mets_server_pid}"], + shell=False, + universal_newlines=True + ) + return + # TODO: Reconsider this again + # Not having this sleep here causes connection errors + # on the last request processed by the processing worker. + # Sometimes 3 seconds is enough, sometimes not. + sleep(5) + stop_mets_server(mets_server_url=mets_server_url) + return diff --git a/src/ocrd_network/runtime_data/hosts.py b/src/ocrd_network/runtime_data/hosts.py new file mode 100644 index 000000000..f46a871f6 --- /dev/null +++ b/src/ocrd_network/runtime_data/hosts.py @@ -0,0 +1,225 @@ +from logging import Logger +from time import sleep +from typing import Dict, List, Union + +from .connection_clients import create_docker_client, create_ssh_client +from .network_agents import AgentType, DataNetworkAgent, DataProcessingWorker, DataProcessorServer, DeployType + + +class DataHost: + def __init__( + self, host: str, username: str, password: str, keypath: str, workers: List[Dict], servers: List[Dict] + ) -> None: + self.host = host + self.username = username + self.password = password + self.keypath = keypath + + # These flags are used to track whether a connection of the specified + # type should be created based on the received config file + self.needs_ssh_connector: bool = False + self.needs_docker_connector: bool = False + + # Connection clients, ssh for native deployment, docker for docker deployment + self.ssh_client = None + self.docker_client = None + + # Time to wait between deploying agents + self.wait_between_agent_deploys: float = 0.3 + + # Lists of network agents based on their agent and deployment type + self.network_agents_worker_native = [] + self.network_agents_worker_docker = [] + self.network_agents_server_native = [] + self.network_agents_server_docker = [] + + if not workers: + workers = [] + if not servers: + servers = [] + + self.__parse_network_agents_workers(processing_workers=workers) + self.__parse_network_agents_servers(processor_servers=servers) + + # Used for caching deployed Processor Servers' ports on the current host + # Key: processor_name, Value: list of 
ports + self.processor_servers_ports: dict = {} + + def __add_deployed_agent_server_port_to_cache(self, processor_name: str, port: int) -> None: + if processor_name not in self.processor_servers_ports: + self.processor_servers_ports[processor_name] = [port] + return + self.processor_servers_ports[processor_name].append(port) + + def __append_network_agent_to_lists(self, agent_data: DataNetworkAgent) -> None: + if agent_data.deploy_type != DeployType.DOCKER and agent_data.deploy_type != DeployType.NATIVE: + raise ValueError(f"Network agent deploy type is unknown: {agent_data.deploy_type}") + if agent_data.agent_type != AgentType.PROCESSING_WORKER and agent_data.agent_type != AgentType.PROCESSOR_SERVER: + raise ValueError(f"Network agent type is unknown: {agent_data.agent_type}") + + if agent_data.deploy_type == DeployType.NATIVE: + self.needs_ssh_connector = True + if agent_data.agent_type == AgentType.PROCESSING_WORKER: + self.network_agents_worker_native.append(agent_data) + if agent_data.agent_type == AgentType.PROCESSOR_SERVER: + self.network_agents_server_native.append(agent_data) + if agent_data.deploy_type == DeployType.DOCKER: + self.needs_docker_connector = True + if agent_data.agent_type == AgentType.PROCESSING_WORKER: + self.network_agents_worker_docker.append(agent_data) + if agent_data.agent_type == AgentType.PROCESSOR_SERVER: + self.network_agents_server_docker.append(agent_data) + + def __parse_network_agents_servers(self, processor_servers: List[Dict]): + for server in processor_servers: + server_data = DataProcessorServer( + processor_name=server["name"], deploy_type=server["deploy_type"], host=self.host, + port=int(server["port"]), init_by_config=True, pid=None + ) + self.__append_network_agent_to_lists(agent_data=server_data) + + def __parse_network_agents_workers(self, processing_workers: List[Dict]): + for worker in processing_workers: + worker_data = DataProcessingWorker( + processor_name=worker["name"], deploy_type=worker["deploy_type"], host=self.host, + init_by_config=True, pid=None + ) + for _ in range(int(worker["number_of_instance"])): + self.__append_network_agent_to_lists(agent_data=worker_data) + + def create_connection_client(self, client_type: str): + if client_type not in ["docker", "ssh"]: + raise ValueError(f"Host client type cannot be of type: {client_type}") + if client_type == "ssh": + self.ssh_client = create_ssh_client(self.host, self.username, self.password, self.keypath) + return self.ssh_client + if client_type == "docker": + self.docker_client = create_docker_client(self.host, self.username, self.password, self.keypath) + return self.docker_client + + def __deploy_network_agent( + self, logger: Logger, agent_data: Union[DataProcessorServer, DataProcessingWorker], + mongodb_url: str, rabbitmq_url: str + ) -> None: + deploy_type = agent_data.deploy_type + agent_type = agent_data.agent_type + name = agent_data.processor_name + agent_info = f"network agent: {agent_type}, deploy: {deploy_type}, name: {name}, host: {self.host}" + logger.info(f"Deploying {agent_info}") + + connection_client = None + if deploy_type == DeployType.NATIVE: + assert self.ssh_client, f"SSH client connection missing." + connection_client = self.ssh_client + if deploy_type == DeployType.DOCKER: + assert self.docker_client, f"Docker client connection missing."
+ connection_client = self.docker_client + + if agent_type == AgentType.PROCESSING_WORKER: + agent_data.deploy_network_agent(logger, connection_client, mongodb_url, rabbitmq_url) + if agent_type == AgentType.PROCESSOR_SERVER: + agent_data.deploy_network_agent(logger, connection_client, mongodb_url) + + sleep(self.wait_between_agent_deploys) + + def __deploy_network_agents_workers(self, logger: Logger, mongodb_url: str, rabbitmq_url: str): + logger.info(f"Deploying processing workers on host: {self.host}") + amount_workers = len(self.network_agents_worker_native) + len(self.network_agents_worker_docker) + if not amount_workers: + logger.info(f"No processing workers found to be deployed") + for data_worker in self.network_agents_worker_native: + self.__deploy_network_agent(logger, data_worker, mongodb_url, rabbitmq_url) + for data_worker in self.network_agents_worker_docker: + self.__deploy_network_agent(logger, data_worker, mongodb_url, rabbitmq_url) + + def __deploy_network_agents_servers(self, logger: Logger, mongodb_url: str, rabbitmq_url: str): + logger.info(f"Deploying processor servers on host: {self.host}") + amount_servers = len(self.network_agents_server_native) + len(self.network_agents_server_docker) + if not amount_servers: + logger.info(f"No processor servers found to be deployed") + for data_server in self.network_agents_server_native: + self.__deploy_network_agent(logger, data_server, mongodb_url, rabbitmq_url) + self.__add_deployed_agent_server_port_to_cache(data_server.processor_name, data_server.port) + for data_server in self.network_agents_server_docker: + self.__deploy_network_agent(logger, data_server, mongodb_url, rabbitmq_url) + self.__add_deployed_agent_server_port_to_cache(data_server.processor_name, data_server.port) + + def deploy_network_agents(self, logger: Logger, mongodb_url: str, rabbitmq_url: str) -> None: + if self.needs_ssh_connector and not self.ssh_client: + logger.debug("Creating missing ssh connector before deploying") + self.ssh_client = self.create_connection_client(client_type="ssh") + if self.needs_docker_connector: + logger.debug("Creating missing docker connector before deploying") + self.docker_client = self.create_connection_client(client_type="docker") + self.__deploy_network_agents_workers(logger=logger, mongodb_url=mongodb_url, rabbitmq_url=rabbitmq_url) + self.__deploy_network_agents_servers(logger=logger, mongodb_url=mongodb_url, rabbitmq_url=rabbitmq_url) + if self.ssh_client: + self.ssh_client.close() + self.ssh_client = None + if self.docker_client: + self.docker_client.close() + self.docker_client = None + + def __stop_network_agent(self, logger: Logger, name: str, deploy_type: DeployType, agent_type: AgentType, pid: str): + agent_info = f"network agent: {agent_type}, deploy: {deploy_type}, name: {name}" + if not pid: + logger.warning(f"No pid was passed for {agent_info}") + return + agent_info += f", pid: {pid}" + logger.info(f"Stopping {agent_info}") + if deploy_type == DeployType.NATIVE: + assert self.ssh_client, f"SSH client connection missing" + self.ssh_client.exec_command(f"kill {pid}") + if deploy_type == DeployType.DOCKER: + assert self.docker_client, f"Docker client connection missing" + self.docker_client.containers.get(pid).stop() + + def __stop_network_agents_workers(self, logger: Logger): + logger.info(f"Stopping processing workers on host: {self.host}") + amount_workers = len(self.network_agents_worker_native) + len(self.network_agents_worker_docker) + if not amount_workers: + logger.warning(f"No active processing 
workers to be stopped.") + for worker in self.network_agents_worker_native: + self.__stop_network_agent(logger, worker.processor_name, worker.deploy_type, worker.agent_type, worker.pid) + self.network_agents_worker_native = [] + for worker in self.network_agents_worker_docker: + self.__stop_network_agent(logger, worker.processor_name, worker.deploy_type, worker.agent_type, worker.pid) + self.network_agents_worker_docker = [] + + def __stop_network_agents_servers(self, logger: Logger): + logger.info(f"Stopping processor servers on host: {self.host}") + amount_servers = len(self.network_agents_server_native) + len(self.network_agents_server_docker) + if not amount_servers: + logger.warning(f"No active processor servers to be stopped.") + for server in self.network_agents_server_native: + self.__stop_network_agent(logger, server.processor_name, server.deploy_type, server.agent_type, server.pid) + self.network_agents_server_native = [] + for server in self.network_agents_server_docker: + self.__stop_network_agent(logger, server.processor_name, server.deploy_type, server.agent_type, server.pid) + self.network_agents_server_docker = [] + + def stop_network_agents(self, logger: Logger): + if self.needs_ssh_connector and not self.ssh_client: + logger.debug("Creating missing ssh connector before stopping") + self.ssh_client = self.create_connection_client(client_type="ssh") + if self.needs_docker_connector and not self.docker_client: + logger.debug("Creating missing docker connector before stopping") + self.docker_client = self.create_connection_client(client_type="docker") + self.__stop_network_agents_workers(logger=logger) + self.__stop_network_agents_servers(logger=logger) + if self.ssh_client: + self.ssh_client.close() + self.ssh_client = None + if self.docker_client: + self.docker_client.close() + self.docker_client = None + + def resolve_processor_server_url(self, processor_name: str) -> str: + processor_server_url = '' + for data_server in self.network_agents_server_docker: + if data_server.processor_name == processor_name: + processor_server_url = f"http://{self.host}:{data_server.port}/" + for data_server in self.network_agents_server_native: + if data_server.processor_name == processor_name: + processor_server_url = f"http://{self.host}:{data_server.port}/" + return processor_server_url diff --git a/src/ocrd_network/runtime_data/network_agents.py b/src/ocrd_network/runtime_data/network_agents.py new file mode 100644 index 000000000..6e25d0450 --- /dev/null +++ b/src/ocrd_network/runtime_data/network_agents.py @@ -0,0 +1,110 @@ +from logging import Logger +from typing import Any + +from re import search as re_search +from ..constants import AgentType, DeployType + + +# TODO: Find appropriate replacement for the hack +def deploy_agent_native_get_pid_hack(logger: Logger, ssh_client, start_cmd: str): + channel = ssh_client.invoke_shell() + stdin, stdout = channel.makefile("wb"), channel.makefile("rb") + logger.debug(f"Executing command: {start_cmd}") + + # TODO: This hack should still be fixed + # Note left from @joschrew + # the only way (I could find) to make it work to start a process in the background and + # return early is this construction. The pid of the last started background process is + # printed with `echo $!` but it is printed inbetween other output. 
Because of that I added + # `xyz` before and after the code to easily be able to filter out the pid via regex when + # returning from the function + + stdin.write(f"{start_cmd}\n") + stdin.write("echo xyz$!xyz \n exit \n") + output = stdout.read().decode("utf-8") + stdout.close() + stdin.close() + return re_search(r"xyz([0-9]+)xyz", output).group(1) # type: ignore + + +# TODO: Implement the actual method that is missing +def deploy_agent_docker_template(logger: Logger, docker_client, start_cmd: str): + """ + logger.debug(f"Executing command: {start_cmd}") + res = docker_client.containers.run("debian", "sleep 500s", detach=True, remove=True) + assert res and res.id, f"Starting docker network agent has failed with command: {start_cmd}" + return res.id + """ + raise Exception("Deploying docker type agents is not supported yet!") + + +class DataNetworkAgent: + def __init__( + self, processor_name: str, deploy_type: DeployType, agent_type: AgentType, + host: str, init_by_config: bool, pid: Any = None + ) -> None: + self.processor_name = processor_name + self.deploy_type = deploy_type + self.host = host + self.deployed_by_config = init_by_config + self.agent_type = agent_type + # The id is assigned when the agent is deployed + self.pid = pid + + def _start_native_instance(self, logger: Logger, ssh_client, start_cmd: str): + if self.deploy_type != DeployType.NATIVE: + raise RuntimeError(f"Mismatch of deploy type when starting network agent: {self.processor_name}") + agent_pid = deploy_agent_native_get_pid_hack(logger=logger, ssh_client=ssh_client, start_cmd=start_cmd) + return agent_pid + + def _start_docker_instance(self, logger: Logger, docker_client, start_cmd: str): + if self.deploy_type != DeployType.DOCKER: + raise RuntimeError(f"Mismatch of deploy type when starting network agent: {self.processor_name}") + agent_pid = deploy_agent_docker_template(logger=logger, docker_client=docker_client, start_cmd=start_cmd) + return agent_pid + + +class DataProcessingWorker(DataNetworkAgent): + def __init__( + self, processor_name: str, deploy_type: DeployType, host: str, init_by_config: bool, pid: Any = None + ) -> None: + super().__init__( + processor_name=processor_name, host=host, deploy_type=deploy_type, agent_type=AgentType.PROCESSING_WORKER, + init_by_config=init_by_config, pid=pid + ) + + def deploy_network_agent(self, logger: Logger, connector_client, database_url: str, queue_url: str): + if self.deploy_type == DeployType.NATIVE: + start_cmd = f"{self.processor_name} {self.agent_type} --database {database_url} --queue {queue_url} &" + self.pid = self._start_native_instance(logger, connector_client, start_cmd) + return self.pid + if self.deploy_type == DeployType.DOCKER: + # TODO: add real command to start processing worker in docker here + start_cmd = f"" + self.pid = self._start_docker_instance(logger, connector_client, start_cmd) + return self.pid + raise RuntimeError(f"Unknown deploy type of {self.__dict__}") + + +class DataProcessorServer(DataNetworkAgent): + def __init__( + self, processor_name: str, deploy_type: DeployType, host: str, port: int, init_by_config: bool, pid: Any = None + ) -> None: + super().__init__( + processor_name=processor_name, host=host, deploy_type=deploy_type, agent_type=AgentType.PROCESSOR_SERVER, + init_by_config=init_by_config, pid=pid + ) + self.port = port + + def deploy_network_agent(self, logger: Logger, connector_client, database_url: str): + agent_address = f"{self.host}:{self.port}" + if self.deploy_type == DeployType.NATIVE: + start_cmd = 
f"{self.processor_name} {self.agent_type} --address {agent_address} --database {database_url} &" + self.pid = self._start_native_instance(logger, connector_client, start_cmd) + return self.pid + if self.deploy_type == DeployType.DOCKER: + # TODO: add real command to start processor server in docker here + start_cmd = f"" + self.pid = self._start_docker_instance(logger, connector_client, start_cmd) + return self.pid + raise RuntimeError(f"Unknown deploy type of {self.__dict__}") diff --git a/src/ocrd_network/runtime_data/network_services.py b/src/ocrd_network/runtime_data/network_services.py new file mode 100644 index 000000000..3b4c52a0b --- /dev/null +++ b/src/ocrd_network/runtime_data/network_services.py @@ -0,0 +1,160 @@ +from __future__ import annotations +from logging import Logger +from typing import Any, Dict, List, Optional, Union + +from ..constants import DOCKER_IMAGE_MONGO_DB, DOCKER_IMAGE_RABBIT_MQ, DOCKER_RABBIT_MQ_FEATURES +from ..database import verify_mongodb_available +from ..rabbitmq_utils import verify_rabbitmq_available +from .connection_clients import create_docker_client + + +class DataNetworkService: + def __init__( + self, host: str, port: int, ssh_username: str, ssh_keypath: str, ssh_password: str, + cred_username: str, cred_password: str, service_url: str, skip_deployment: bool, pid: Optional[Any] + ) -> None: + self.host = host + self.port = port + self.ssh_username = ssh_username + self.ssh_keypath = ssh_keypath + self.ssh_password = ssh_password + self.cred_username = cred_username + self.cred_password = cred_password + self.service_url = service_url + self.skip_deployment = skip_deployment + self.pid = pid + + @staticmethod + def deploy_docker_service( + logger: Logger, service_data: Union[DataMongoDB, DataRabbitMQ], image: str, env: Optional[List[str]], + ports_mapping: Optional[Dict], detach: bool = True, remove: bool = True + ) -> None: + if not service_data or not service_data.host: + message = f"Deploying '{image}' has failed - missing service configurations." 
+ logger.exception(message) + raise RuntimeError(message) + logger.info(f"Deploying '{image}' service on '{service_data.host}', detach={detach}, remove={remove}") + logger.info(f"Ports mapping: {ports_mapping}") + logger.info(f"Environment: {env}") + client = create_docker_client( + service_data.host, service_data.ssh_username, service_data.ssh_password, service_data.ssh_keypath + ) + result = client.containers.run(image=image, detach=detach, remove=remove, ports=ports_mapping, environment=env) + if not result or not result.id: + message = f"Failed to deploy '{image}' service on host: {service_data.host}" + logger.exception(message) + raise RuntimeError(message) + service_data.pid = result.id + client.close() + + @staticmethod + def stop_docker_service(logger: Logger, service_data: Union[DataMongoDB, DataRabbitMQ]) -> None: + if not service_data.pid: + logger.warning("No running service found") + return + client = create_docker_client( + service_data.host, service_data.ssh_username, service_data.ssh_password, service_data.ssh_keypath + ) + client.containers.get(service_data.pid).stop() + client.close() + + +class DataMongoDB(DataNetworkService): + def __init__( + self, host: str, port: int, ssh_username: Optional[str], ssh_keypath: Optional[str], + ssh_password: Optional[str], cred_username: Optional[str], cred_password: Optional[str], + skip_deployment: bool, protocol: str = "mongodb" + ) -> None: + service_url = f"{protocol}://{host}:{port}" + if cred_username and cred_password: + service_url = f"{protocol}://{cred_username}:{cred_password}@{host}:{port}" + super().__init__( + host=host, port=port, ssh_username=ssh_username, ssh_keypath=ssh_keypath, ssh_password=ssh_password, + cred_username=cred_username, cred_password=cred_password, service_url=service_url, + skip_deployment=skip_deployment, pid=None + ) + + def deploy_mongodb( + self, logger: Logger, image: str = DOCKER_IMAGE_MONGO_DB, detach: bool = True, remove: bool = True, + env: Optional[List[str]] = None, ports_mapping: Optional[Dict] = None + ) -> str: + if self.skip_deployment: + logger.debug("MongoDB is managed externally. 
Skipping deployment.") + verify_mongodb_available(self.service_url) + return self.service_url + if not env: + env = [] + if self.cred_username: + env = [ + f"MONGO_INITDB_ROOT_USERNAME={self.cred_username}", + f"MONGO_INITDB_ROOT_PASSWORD={self.cred_password}" + ] + if not ports_mapping: + ports_mapping = {27017: self.port} + self.deploy_docker_service(logger, self, image, env, ports_mapping, detach, remove) + verify_mongodb_available(self.service_url) + mongodb_host_info = f"{self.host}:{self.port}" + logger.info(f"The MongoDB was deployed on host: {mongodb_host_info}") + return self.service_url + + def stop_service_mongodb(self, logger: Logger) -> None: + if self.skip_deployment: + return + logger.info("Stopping the MongoDB service...") + self.stop_docker_service(logger, service_data=self) + self.pid = None + logger.info("The MongoDB service is stopped") + + +class DataRabbitMQ(DataNetworkService): + def __init__( + self, host: str, port: int, ssh_username: Optional[str], ssh_keypath: Optional[str], + ssh_password: Optional[str], cred_username: Optional[str], cred_password: Optional[str], + skip_deployment: bool, protocol: str = "amqp", vhost: str = "/" + ) -> None: + self.vhost = f"/{vhost}" if vhost != "/" else vhost + service_url = f"{protocol}://{host}:{port}{self.vhost}" + if cred_username and cred_password: + service_url = f"{protocol}://{cred_username}:{cred_password}@{host}:{port}{self.vhost}" + super().__init__( + host=host, port=port, ssh_username=ssh_username, ssh_keypath=ssh_keypath, ssh_password=ssh_password, + cred_username=cred_username, cred_password=cred_password, service_url=service_url, + skip_deployment=skip_deployment, pid=None + ) + + def deploy_rabbitmq( + self, logger: Logger, image: str = DOCKER_IMAGE_RABBIT_MQ, detach: bool = True, remove: bool = True, + env: Optional[List[str]] = None, ports_mapping: Optional[Dict] = None + ) -> str: + rmq_host, rmq_port, rmq_vhost = self.host, int(self.port), self.vhost + rmq_user, rmq_password = self.cred_username, self.cred_password + if self.skip_deployment: + logger.debug(f"RabbitMQ is managed externally. 
Skipping deployment.") + verify_rabbitmq_available(logger=logger, rabbitmq_address=self.service_url) + return self.service_url + if not env: + env = [ + # The default credentials to be used by the processing workers + f"RABBITMQ_DEFAULT_USER={rmq_user}", + f"RABBITMQ_DEFAULT_PASS={rmq_password}", + # These feature flags are required by default to use the newer version + f"RABBITMQ_FEATURE_FLAGS={DOCKER_RABBIT_MQ_FEATURES}" + ] + if not ports_mapping: + # 5672, 5671 - used by AMQP 0-9-1 and AMQP 1.0 clients without and with TLS + # 15672, 15671: HTTP API clients, management UI and rabbitmq admin, without and with TLS + # 25672: used for internode and CLI tools communication and is allocated from + # a dynamic range (limited to a single port by default, computed as AMQP port + 20000) + ports_mapping = {5672: self.port, 15672: 15672, 25672: 25672} + self.deploy_docker_service(logger, self, image, env, ports_mapping, detach, remove) + verify_rabbitmq_available(logger=logger, rabbitmq_address=self.service_url) + logger.info(f"The RabbitMQ server was deployed on host: {rmq_host}:{rmq_port}{rmq_vhost}") + return self.service_url + + def stop_service_rabbitmq(self, logger: Logger) -> None: + if self.skip_deployment: + return + logger.info("Stopping the RabbitMQ service...") + self.stop_docker_service(logger, service_data=self) + self.pid = None + logger.info("The RabbitMQ service is stopped") diff --git a/src/ocrd_network/server_cache.py b/src/ocrd_network/server_cache.py index 591ee1c8b..b57f3fd23 100644 --- a/src/ocrd_network/server_cache.py +++ b/src/ocrd_network/server_cache.py @@ -1,28 +1,23 @@ from __future__ import annotations from typing import Dict, List -from logging import FileHandler, Formatter -from ocrd_utils import getLogger, LOG_FORMAT +from ocrd_utils import getLogger +from .constants import JobState, SERVER_ALL_PAGES_PLACEHOLDER from .database import db_get_processing_job, db_update_processing_job -from .logging import ( +from .logging_utils import ( + configure_file_handler_with_formatter, get_cache_locked_pages_logging_file_path, get_cache_processing_requests_logging_file_path ) -from .models import PYJobInput, StateEnum - -__all__ = [ - 'CacheLockedPages', - 'CacheProcessingRequests' -] +from .models import PYJobInput +from .utils import call_sync class CacheLockedPages: def __init__(self) -> None: self.log = getLogger("ocrd_network.server_cache.locked_pages") log_file = get_cache_locked_pages_logging_file_path() - log_fh = FileHandler(filename=log_file, mode='a') - log_fh.setFormatter(Formatter(LOG_FORMAT)) - self.log.addHandler(log_fh) + configure_file_handler_with_formatter(self.log, log_file=log_file, mode="a") # Used for keeping track of locked pages for a workspace # Key: `path_to_mets` if already resolved else `workspace_id` @@ -30,93 +25,75 @@ def __init__(self) -> None: # and the values are list of strings representing the locked pages self.locked_pages: Dict[str, Dict[str, List[str]]] = {} # Used as a placeholder to lock all pages when no page_id is specified - self.placeholder_all_pages: str = "all_pages" + self.placeholder_all_pages: str = SERVER_ALL_PAGES_PLACEHOLDER def check_if_locked_pages_for_output_file_grps( - self, - workspace_key: str, - output_file_grps: List[str], - page_ids: List[str] + self, workspace_key: str, output_file_grps: List[str], page_ids: List[str] ) -> bool: if not self.locked_pages.get(workspace_key, None): self.log.debug(f"No entry found in the locked pages cache for workspace key: {workspace_key}") return False - for output_fileGrp in 
output_file_grps: - if output_fileGrp in self.locked_pages[workspace_key]: - if self.placeholder_all_pages in self.locked_pages[workspace_key][output_fileGrp]: - self.log.debug(f"Caching the received request due to locked output file grp pages") + debug_message = f"Caching the received request due to locked output file grp pages." + for file_group in output_file_grps: + if file_group in self.locked_pages[workspace_key]: + if self.placeholder_all_pages in self.locked_pages[workspace_key][file_group]: + self.log.debug(debug_message) return True - if not set(self.locked_pages[workspace_key][output_fileGrp]).isdisjoint(page_ids): - self.log.debug(f"Caching the received request due to locked output file grp pages") + if not set(self.locked_pages[workspace_key][file_group]).isdisjoint(page_ids): + self.log.debug(debug_message) return True return False - def get_locked_pages( - self, - workspace_key: str - ) -> Dict[str, List[str]]: + def get_locked_pages(self, workspace_key: str) -> Dict[str, List[str]]: if not self.locked_pages.get(workspace_key, None): self.log.debug(f"No locked pages available for workspace key: {workspace_key}") return {} return self.locked_pages[workspace_key] - def lock_pages( - self, - workspace_key: str, - output_file_grps: List[str], - page_ids: List[str] - ) -> None: + def lock_pages(self, workspace_key: str, output_file_grps: List[str], page_ids: List[str]) -> None: if not self.locked_pages.get(workspace_key, None): self.log.debug(f"No entry found in the locked pages cache for workspace key: {workspace_key}") self.log.debug(f"Creating an entry in the locked pages cache for workspace key: {workspace_key}") self.locked_pages[workspace_key] = {} - - for output_fileGrp in output_file_grps: - if output_fileGrp not in self.locked_pages[workspace_key]: - self.log.debug(f"Creating an empty list for output file grp: {output_fileGrp}") - self.locked_pages[workspace_key][output_fileGrp] = [] + for file_group in output_file_grps: + if file_group not in self.locked_pages[workspace_key]: + self.log.debug(f"Creating an empty list for output file grp: {file_group}") + self.locked_pages[workspace_key][file_group] = [] # The page id list is not empty - only some pages are in the request if page_ids: - self.log.debug(f"Locking pages for `{output_fileGrp}`: {page_ids}") - self.locked_pages[workspace_key][output_fileGrp].extend(page_ids) - self.log.debug(f"Locked pages of `{output_fileGrp}`: " - f"{self.locked_pages[workspace_key][output_fileGrp]}") + self.log.debug(f"Locking pages for '{file_group}': {page_ids}") + self.locked_pages[workspace_key][file_group].extend(page_ids) + self.log.debug(f"Locked pages of '{file_group}': " + f"{self.locked_pages[workspace_key][file_group]}") else: # Lock all pages with a single value - self.log.debug(f"Locking pages for `{output_fileGrp}`: {self.placeholder_all_pages}") - self.locked_pages[workspace_key][output_fileGrp].append(self.placeholder_all_pages) - - def unlock_pages( - self, - workspace_key: str, - output_file_grps: List[str], - page_ids: List[str] - ) -> None: + self.log.debug(f"Locking pages for '{file_group}': {self.placeholder_all_pages}") + self.locked_pages[workspace_key][file_group].append(self.placeholder_all_pages) + + def unlock_pages(self, workspace_key: str, output_file_grps: List[str], page_ids: List[str]) -> None: if not self.locked_pages.get(workspace_key, None): self.log.debug(f"No entry found in the locked pages cache for workspace key: {workspace_key}") return - for output_fileGrp in output_file_grps: - if output_fileGrp 
in self.locked_pages[workspace_key]: + for file_group in output_file_grps: + if file_group in self.locked_pages[workspace_key]: if page_ids: # Unlock the previously locked pages - self.log.debug(f"Unlocking pages of `{output_fileGrp}`: {page_ids}") - self.locked_pages[workspace_key][output_fileGrp] = \ - [x for x in self.locked_pages[workspace_key][output_fileGrp] if x not in page_ids] - self.log.debug(f"Remaining locked pages of `{output_fileGrp}`: " - f"{self.locked_pages[workspace_key][output_fileGrp]}") + self.log.debug(f"Unlocking pages of '{file_group}': {page_ids}") + self.locked_pages[workspace_key][file_group] = \ + [x for x in self.locked_pages[workspace_key][file_group] if x not in page_ids] + self.log.debug(f"Remaining locked pages of '{file_group}': " + f"{self.locked_pages[workspace_key][file_group]}") else: # Remove the single variable used to indicate all pages are locked - self.log.debug(f"Unlocking all pages for: {output_fileGrp}") - self.locked_pages[workspace_key][output_fileGrp].remove(self.placeholder_all_pages) + self.log.debug(f"Unlocking all pages for: {file_group}") + self.locked_pages[workspace_key][file_group].remove(self.placeholder_all_pages) class CacheProcessingRequests: def __init__(self) -> None: self.log = getLogger("ocrd_network.server_cache.processing_requests") log_file = get_cache_processing_requests_logging_file_path() - log_fh = FileHandler(filename=log_file, mode='a') - log_fh.setFormatter(Formatter(LOG_FORMAT)) - self.log.addHandler(log_fh) + configure_file_handler_with_formatter(self.log, log_file=log_file, mode="a") # Used for buffering/caching processing requests in the Processing Server # Key: `path_to_mets` if already resolved else `workspace_id` @@ -128,7 +105,7 @@ def __init__(self) -> None: # Key: `path_to_mets` if already resolved else `workspace_id` # Value: integer which holds the amount of jobs pushed to the RabbitMQ # but no internal callback was yet invoked - self.__processing_counter: Dict[str, int] = {} + self.processing_counter: Dict[str, int] = {} @staticmethod async def __check_if_job_deps_met(dependencies: List[str]) -> bool: @@ -136,20 +113,28 @@ async def __check_if_job_deps_met(dependencies: List[str]) -> bool: for dependency_job_id in dependencies: try: dependency_job_state = (await db_get_processing_job(dependency_job_id)).state + # Found a dependent job whose state is not success + if dependency_job_state != JobState.success: + return False except ValueError: # job_id not (yet) in db. 
Dependency not met return False - # Found a dependent job whose state is not success - if dependency_job_state != StateEnum.success: - return False return True + def __print_job_input_debug_message(self, job_input: PYJobInput): + debug_message = "Processing job input" + debug_message += f", processor: {job_input.processor_name}" + debug_message += f", page ids: {job_input.page_id}" + debug_message += f", job id: {job_input.job_id}" + debug_message += f", job depends on: {job_input.depends_on}" + self.log.debug(debug_message) + async def consume_cached_requests(self, workspace_key: str) -> List[PYJobInput]: if not self.has_workspace_cached_requests(workspace_key=workspace_key): self.log.debug(f"No jobs to be consumed for workspace key: {workspace_key}") return [] found_consume_requests = [] - for i, current_element in enumerate(self.processing_requests[workspace_key]): + for current_element in self.processing_requests[workspace_key]: # Request has other job dependencies if current_element.depends_on: satisfied_dependencies = await self.__check_if_job_deps_met(current_element.depends_on) @@ -161,15 +146,17 @@ async def consume_cached_requests(self, workspace_key: str) -> List[PYJobInput]: try: (self.processing_requests[workspace_key]).remove(found_element) # self.log.debug(f"Found cached request to be processed: {found_request}") - self.log.debug(f"Found cached request: {found_element.processor_name}, {found_element.page_id}, " - f"{found_element.job_id}, depends_on: {found_element.depends_on}") + self.__print_job_input_debug_message(job_input=found_element) found_requests.append(found_element) except ValueError: - # The ValueError is not an issue since the - # element was removed by another instance + # The ValueError is not an issue since the element was removed by another instance continue return found_requests + @call_sync + async def sync_consume_cached_requests(self, workspace_key: str) -> List[PYJobInput]: + return await self.consume_cached_requests(workspace_key=workspace_key) + def update_request_counter(self, workspace_key: str, by_value: int) -> int: """ A method used to increase/decrease the internal counter of some workspace_key by `by_value`. 
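(Context for the new sync_* wrappers in this file: they use the @call_sync decorator from src/ocrd_network/utils.py, added later in this patch, which simply drives a returned coroutine to completion on the current event loop so the async cache helpers can also be invoked from synchronous code paths. Below is a minimal, self-contained sketch of that pattern; ToyCache and its attributes are illustrative only and are not part of the patch.)

    from asyncio import iscoroutine, get_event_loop
    from functools import wraps

    def call_sync(func):
        # If the wrapped callable returns a coroutine, run it to completion
        # on the current event loop; plain return values pass through as-is.
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if iscoroutine(result):
                return get_event_loop().run_until_complete(result)
            return result
        return func_wrapper

    class ToyCache:
        def __init__(self):
            # stand-in for the real per-workspace request cache
            self.processing_requests = {"ws1": ["job-a", "job-b"]}

        async def consume_cached_requests(self, workspace_key):
            # stand-in for the real async cache/database lookup
            return self.processing_requests.pop(workspace_key, [])

        @call_sync
        async def sync_consume_cached_requests(self, workspace_key):
            return await self.consume_cached_requests(workspace_key)

    cache = ToyCache()
    # callable from plain synchronous code, no event loop management needed
    print(cache.sync_consume_cached_requests("ws1"))  # ['job-a', 'job-b']
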
@@ -177,19 +164,18 @@ def update_request_counter(self, workspace_key: str, by_value: int) -> int: """ # If a record counter of this workspace key does not exist # in the requests counter cache yet, create one and assign 0 - if not self.__processing_counter.get(workspace_key, None): + if not self.processing_counter.get(workspace_key, None): self.log.debug(f"Creating an internal request counter for workspace key: {workspace_key}") - self.__processing_counter[workspace_key] = 0 - self.__processing_counter[workspace_key] = self.__processing_counter[workspace_key] + by_value - return self.__processing_counter[workspace_key] + self.processing_counter[workspace_key] = 0 + self.processing_counter[workspace_key] = self.processing_counter[workspace_key] + by_value + return self.processing_counter[workspace_key] def cache_request(self, workspace_key: str, data: PYJobInput): # If a record queue of this workspace key does not exist in the requests cache if not self.processing_requests.get(workspace_key, None): self.log.debug(f"Creating an internal request queue for workspace_key: {workspace_key}") self.processing_requests[workspace_key] = [] - self.log.debug(f"Caching request: {data.processor_name}, {data.page_id}, " - f"{data.job_id}, depends_on: {data.depends_on}") + self.__print_job_input_debug_message(job_input=data) # Add the processing request to the end of the internal queue self.processing_requests[workspace_key].append(data) @@ -206,32 +192,37 @@ async def cancel_dependent_jobs(self, workspace_key: str, processing_job_id: str for cancel_element in found_cancel_requests: try: self.processing_requests[workspace_key].remove(cancel_element) - self.log.debug(f"For job id: `{processing_job_id}`, " - f"cancelling: {cancel_element.job_id}") + self.log.debug(f"For job id: '{processing_job_id}', cancelling job id: '{cancel_element.job_id}'") cancelled_jobs.append(cancel_element) - await db_update_processing_job(job_id=cancel_element.job_id, state=StateEnum.cancelled) + await db_update_processing_job(job_id=cancel_element.job_id, state=JobState.cancelled) # Recursively cancel dependent jobs for the cancelled job recursively_cancelled = await self.cancel_dependent_jobs( - workspace_key=workspace_key, - processing_job_id=cancel_element.job_id + workspace_key=workspace_key, processing_job_id=cancel_element.job_id ) # Add the recursively cancelled jobs to the main list of cancelled jobs cancelled_jobs.extend(recursively_cancelled) except ValueError: - # The ValueError is not an issue since the - # element was removed by another instance + # The ValueError is not an issue since the element was removed by another instance continue return cancelled_jobs + @call_sync + async def sync_cancel_dependent_jobs(self, workspace_key: str, processing_job_id: str) -> List[PYJobInput]: + # A synchronous wrapper around the async method + return await self.cancel_dependent_jobs(workspace_key=workspace_key, processing_job_id=processing_job_id) + async def is_caching_required(self, job_dependencies: List[str]) -> bool: if not len(job_dependencies): - # no dependencies found - return False + return False # no dependencies found if await self.__check_if_job_deps_met(job_dependencies): - # all dependencies are met - return False + return False # all dependencies are met return True + @call_sync + async def sync_is_caching_required(self, job_dependencies: List[str]) -> bool: + # A synchronous wrapper around the async method + return await self.is_caching_required(job_dependencies=job_dependencies) + def 
has_workspace_cached_requests(self, workspace_key: str) -> bool: if not self.processing_requests.get(workspace_key, None): self.log.debug(f"In processing requests cache, no workspace key found: {workspace_key}") diff --git a/src/ocrd_network/server_utils.py b/src/ocrd_network/server_utils.py index fd7c0f796..cc0c59ec6 100644 --- a/src/ocrd_network/server_utils.py +++ b/src/ocrd_network/server_utils.py @@ -1,94 +1,241 @@ -import re -from fastapi import HTTPException, status +from fastapi import HTTPException, status, UploadFile from fastapi.responses import FileResponse +from httpx import AsyncClient, Timeout +from json import dumps, loads +from logging import Logger from pathlib import Path -from typing import List +from requests import get as requests_get +from typing import Dict, List, Union +from urllib.parse import urljoin + +from ocrd.resolver import Resolver +from ocrd.task_sequence import ProcessorTask +from ocrd.workspace import Workspace from ocrd_validators import ParameterValidator -from ocrd_utils import ( - generate_range, - REGEX_PREFIX -) + from .database import ( + db_create_workspace, db_get_processing_job, - db_get_workspace, + db_get_workflow_job, + db_get_workflow_script, + db_get_workspace ) -from .models import PYJobInput, PYJobOutput +from .models import DBProcessorJob, DBWorkflowJob, DBWorkspace, PYJobInput, PYJobOutput +from .rabbitmq_utils import OcrdProcessingMessage +from .utils import ( + calculate_processing_request_timeout, + expand_page_ids, + generate_created_time, + generate_workflow_content, + get_ocrd_workspace_physical_pages +) + + +def create_processing_message(logger: Logger, job: DBProcessorJob) -> OcrdProcessingMessage: + try: + processing_message = OcrdProcessingMessage( + job_id=job.job_id, + processor_name=job.processor_name, + created_time=generate_created_time(), + path_to_mets=job.path_to_mets, + workspace_id=job.workspace_id, + input_file_grps=job.input_file_grps, + output_file_grps=job.output_file_grps, + page_id=job.page_id, + parameters=job.parameters, + result_queue_name=job.result_queue_name, + callback_url=job.callback_url, + internal_callback_url=job.internal_callback_url + ) + return processing_message + except ValueError as error: + message = f"Failed to create OcrdProcessingMessage from DBProcessorJob" + raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error) + + +async def create_workspace_if_not_exists(logger: Logger, mets_path: str) -> DBWorkspace: + try: + # Core cannot create workspaces by API, but the Processing Server needs + # the workspace in the database. The workspace is created if the path is + # available locally and not existing in the database - since it has not + # been uploaded through the Workspace Server. + db_workspace = await db_create_workspace(mets_path) + return db_workspace + except FileNotFoundError as error: + message = f"Mets file path not existing: {mets_path}" + raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message, error) + + +async def get_from_database_workflow_job(logger: Logger, workflow_job_id: str) -> DBWorkflowJob: + try: + workflow_job = await db_get_workflow_job(workflow_job_id) + return workflow_job + except ValueError as error: + message = f"Workflow job with id '{workflow_job_id}' not found in the DB." 
+ raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message, error) + + +async def get_from_database_workspace( + logger: Logger, + workspace_id: str = None, + workspace_mets_path: str = None +) -> DBWorkspace: + try: + db_workspace = await db_get_workspace(workspace_id, workspace_mets_path) + return db_workspace + except ValueError as error: + message = f"Workspace with id '{workspace_id}' not found in the DB." + raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message, error) + + +def get_page_ids_list(logger: Logger, mets_path: str, page_id: str) -> List[str]: + try: + if page_id: + page_range = expand_page_ids(page_id) + else: + # If no page_id is specified, all physical pages are assigned as page range + page_range = get_ocrd_workspace_physical_pages(mets_path=mets_path) + return page_range + except Exception as error: + message = f"Failed to determine page range for mets path: {mets_path}" + raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error) -async def _get_processor_job(logger, job_id: str) -> PYJobOutput: +async def _get_processor_job(logger: Logger, job_id: str) -> PYJobOutput: """ Return processing job-information from the database """ try: job = await db_get_processing_job(job_id) return job.to_job_output() - except ValueError as e: - logger.exception(f"Processing job with id '{job_id}' not existing, error: {e}") - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Processing job with id '{job_id}' not existing" - ) + except ValueError as error: + message = f"Processing job with id '{job_id}' not existing." + raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error) -async def _get_processor_job_log(logger, job_id: str) -> FileResponse: +async def _get_processor_job_log(logger: Logger, job_id: str) -> FileResponse: db_job = await _get_processor_job(logger, job_id) log_file_path = Path(db_job.log_file_path) return FileResponse(path=log_file_path, filename=log_file_path.name) -async def validate_and_return_mets_path(logger, job_input: PYJobInput) -> str: - # This check is done to return early in case the workspace_id is provided - # but the abs mets path cannot be queried from the DB - if not job_input.path_to_mets and job_input.workspace_id: +def request_processor_server_tool_json(logger: Logger, processor_server_base_url: str) -> Dict: + # Request the ocrd tool json from the Processor Server + try: + response = requests_get( + urljoin(base=processor_server_base_url, url="info"), + headers={"Content-Type": "application/json"} + ) + if response.status_code != 200: + message = f"Failed to retrieve tool json from: {processor_server_base_url}, code: {response.status_code}" + raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message) + return response.json() + except Exception as error: + message = f"Failed to retrieve ocrd tool json from: {processor_server_base_url}" + raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message, error) + + +async def forward_job_to_processor_server( + logger: Logger, job_input: PYJobInput, processor_server_base_url: str +) -> PYJobOutput: + try: + json_data = dumps(job_input.dict(exclude_unset=True, exclude_none=True)) + except Exception as error: + message = f"Failed to json dump the PYJobInput: {job_input}" + raise_http_exception(logger, status.HTTP_500_INTERNAL_SERVER_ERROR, message, error) + + # TODO: The amount of pages should come as a request input + # TODO: cf https://github.com/OCR-D/core/pull/1030/files#r1152551161 + # currently, 
use 200 as a default + request_timeout = calculate_processing_request_timeout(amount_pages=200, timeout_per_page=20.0) + + # Post a processing job to the Processor Server asynchronously + async with AsyncClient(timeout=Timeout(timeout=request_timeout, connect=30.0)) as client: + response = await client.post( + urljoin(base=processor_server_base_url, url="run"), + headers={"Content-Type": "application/json"}, + json=loads(json_data) + ) + if response.status_code != 202: + message = f"Failed to post '{job_input.processor_name}' job to: {processor_server_base_url}" + raise_http_exception(logger, status.HTTP_500_INTERNAL_SERVER_ERROR, message) + job_output = response.json() + return job_output + + +async def get_workflow_content(logger: Logger, workflow_id: str, workflow: Union[UploadFile, None]) -> str: + if not workflow and not workflow_id: + message = "Either 'workflow' must be uploaded as a file or 'workflow_id' must be provided. Both are missing." + raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message) + if workflow_id: try: - db_workspace = await db_get_workspace(job_input.workspace_id) - path_to_mets = db_workspace.workspace_mets_path - except ValueError as e: - logger.exception(f"Workspace with id '{job_input.workspace_id}' not existing: {e}") - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Workspace with id '{job_input.workspace_id}' not existing" - ) - return path_to_mets + db_workflow = await db_get_workflow_script(workflow_id) + return db_workflow.content + except ValueError as error: + message = f"Workflow with id '{workflow_id}' not found" + raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message, error) + return await generate_workflow_content(workflow) + + +async def validate_and_return_mets_path(logger: Logger, job_input: PYJobInput) -> str: + if job_input.workspace_id: + db_workspace = await get_from_database_workspace(logger, job_input.workspace_id) + return db_workspace.workspace_mets_path return job_input.path_to_mets -def expand_page_ids(page_id: str) -> List: - page_ids = [] - if not page_id: - return page_ids - for page_id_token in re.split(r',', page_id): - if page_id_token.startswith(REGEX_PREFIX): - page_ids.append(re.compile(page_id_token[len(REGEX_PREFIX):])) - elif '..' in page_id_token: - page_ids += generate_range(*page_id_token.split('..', 1)) - else: - page_ids += [page_id_token] - return page_ids +def parse_workflow_tasks(logger: Logger, workflow_content: str) -> List[ProcessorTask]: + try: + tasks_list = workflow_content.splitlines() + return [ProcessorTask.parse(task_str) for task_str in tasks_list if task_str.strip()] + except ValueError as error: + message = f"Failed parsing processing tasks from a workflow." 
+        raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error)
 
 
-def validate_job_input(logger, processor_name: str, ocrd_tool: dict, job_input: PYJobInput) -> None:
+def raise_http_exception(logger: Logger, status_code: int, message: str, error: Exception = None) -> None:
+    if error:
+        logger.exception(f"{message} {error}")
+    else:
+        logger.error(message)
+    raise HTTPException(status_code=status_code, detail=message)
+
+
+def validate_job_input(logger: Logger, processor_name: str, ocrd_tool: dict, job_input: PYJobInput) -> None:
     if bool(job_input.path_to_mets) == bool(job_input.workspace_id):
-        logger.exception("Either 'path_to_mets' or 'workspace_id' must be provided, but not both")
-        raise HTTPException(
-            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail="Either 'path_to_mets' or 'workspace_id' must be provided, but not both"
+        message = (
+            "Wrong processing job input format. "
+            "Either 'path_to_mets' or 'workspace_id' must be provided. "
+            "Both are provided or both are missing."
         )
+        raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message)
     if not ocrd_tool:
-        logger.exception(f"Processor '{processor_name}' not available. Empty or missing ocrd_tool")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Processor '{processor_name}' not available. Empty or missing ocrd_tool"
-        )
+        message = f"Processor '{processor_name}' not available. Empty or missing ocrd_tool."
+        raise_http_exception(logger, status.HTTP_404_NOT_FOUND, message)
     try:
         report = ParameterValidator(ocrd_tool).validate(dict(job_input.parameters))
-    except Exception as e:
-        logger.exception(f'Failed to validate processing job against the ocrd_tool: {e}')
-        raise HTTPException(
-            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail='Failed to validate processing job against the ocrd_tool'
-        )
-    else:
-        if not report.is_valid:
-            log_msg = f'Failed to validate processing job against the ocrd_tool, errors: {report.errors}'
-            logger.exception(log_msg)
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=log_msg)
+    except Exception as error:
+        message = f"Failed to validate processing job input against the ocrd tool json of processor: {processor_name}"
+        raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error)
+    if not report.is_valid:
+        message = f"Failed to validate processing job input against the tool json of processor: {processor_name}\n"
+        raise_http_exception(logger, status.HTTP_400_BAD_REQUEST, message + str(report.errors))
+
+
+def validate_workflow(logger: Logger, workflow: str) -> None:
+    """
+    Check whether the workflow is not empty and parseable to a list of ProcessorTask
+    """
+    if not workflow.strip():
+        raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message="Workflow is invalid, empty.")
+    try:
+        tasks_list = workflow.splitlines()
+        [ProcessorTask.parse(task_str) for task_str in tasks_list if task_str.strip()]
+    except ValueError as error:
+        message = "Provided workflow script is invalid, failed to parse ProcessorTasks."
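+        # a single malformed task line makes the whole workflow script invalid (HTTP 422)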
+ raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message, error) + + +def validate_first_task_input_file_groups_existence(logger: Logger, mets_path: str, input_file_grps: List[str]): + # Validate the input file groups of the first task in the workflow + available_groups = Workspace(Resolver(), Path(mets_path).parents[0]).mets.file_groups + for group in input_file_grps: + if group not in available_groups: + message = f"Input file group '{group}' of the first processor not found: {input_file_grps}" + raise_http_exception(logger, status.HTTP_422_UNPROCESSABLE_ENTITY, message) diff --git a/src/ocrd_network/utils.py b/src/ocrd_network/utils.py index 4f66554bc..63cf728e0 100644 --- a/src/ocrd_network/utils.py +++ b/src/ocrd_network/utils.py @@ -1,41 +1,62 @@ +from asyncio import iscoroutine, get_event_loop from datetime import datetime +from fastapi import UploadFile from functools import wraps -from pika import URLParameters -from pymongo import uri_parser as mongo_uri_parser -from re import match as re_match -from requests import get, Session as Session_TCP +from hashlib import md5 +from re import compile as re_compile, split as re_split +from requests import get as requests_get, Session as Session_TCP from requests_unixsocket import Session as Session_UDS -from typing import Dict, List +from time import sleep +from typing import List from uuid import uuid4 -from yaml import safe_load -from ocrd import Resolver, Workspace -from ocrd_validators import ProcessingServerConfigValidator +from ocrd.resolver import Resolver +from ocrd.workspace import Workspace +from ocrd_utils import generate_range, REGEX_PREFIX from .rabbitmq_utils import OcrdResultMessage -from ocrd.task_sequence import ProcessorTask -# Based on: https://gist.github.com/phizaz/20c36c6734878c6ec053245a477572ec def call_sync(func): - import asyncio - + # Based on: https://gist.github.com/phizaz/20c36c6734878c6ec053245a477572ec @wraps(func) def func_wrapper(*args, **kwargs): result = func(*args, **kwargs) - if asyncio.iscoroutine(result): - return asyncio.get_event_loop().run_until_complete(result) + if iscoroutine(result): + return get_event_loop().run_until_complete(result) return result return func_wrapper def calculate_execution_time(start: datetime, end: datetime) -> int: """ - Calculates the difference between `start` and `end` datetime. + Calculates the difference between 'start' and 'end' datetime. Returns the result in milliseconds """ return int((end - start).total_seconds() * 1000) +def calculate_processing_request_timeout(amount_pages: int, timeout_per_page: float = 20.0) -> float: + return amount_pages * timeout_per_page + + +def convert_url_to_uds_format(url: str) -> str: + return f"http+unix://{url.replace('/', '%2F')}" + + +def expand_page_ids(page_id: str) -> List: + page_ids = [] + if not page_id: + return page_ids + for page_id_token in re_split(pattern=r',', string=page_id): + if page_id_token.startswith(REGEX_PREFIX): + page_ids.append(re_compile(pattern=page_id_token[len(REGEX_PREFIX):])) + elif '..' 
in page_id_token:
+            page_ids += generate_range(*page_id_token.split(sep='..', maxsplit=1))
+        else:
+            page_ids += [page_id_token]
+    return page_ids
+
+
 def generate_created_time() -> int:
     return int(datetime.utcnow().timestamp())
 
@@ -49,63 +70,30 @@ def generate_id() -> str:
     return str(uuid4())
 
 
-def is_url_responsive(url: str, retries: int = 0) -> bool:
-    while True:
-        try:
-            response = get(url)
-            if response.status_code == 200:
-                return True
-        except Exception:
-            if retries <= 0:
-                return False
-            retries -= 1
+async def generate_workflow_content(workflow: UploadFile, encoding: str = "utf-8"):
+    return (await workflow.read()).decode(encoding)
 
 
-def validate_and_load_config(config_path: str) -> Dict:
-    # Load and validate the config
-    with open(config_path) as fin:
-        config = safe_load(fin)
-    report = ProcessingServerConfigValidator.validate(config)
-    if not report.is_valid:
-        raise Exception(f'Processing-Server configuration file is invalid:\n{report.errors}')
-    return config
+def generate_workflow_content_hash(workflow_content: str, encoding: str = "utf-8"):
+    return md5(workflow_content.encode(encoding)).hexdigest()
 
 
-def verify_database_uri(mongodb_address: str) -> str:
-    try:
-        # perform validation check
-        mongo_uri_parser.parse_uri(uri=mongodb_address, validate=True)
-    except Exception as error:
-        raise ValueError(f"The MongoDB address '{mongodb_address}' is in wrong format, {error}")
-    return mongodb_address
-
-
-def verify_and_parse_mq_uri(rabbitmq_address: str):
-    """
-    Check the full list of available parameters in the docs here:
-    https://pika.readthedocs.io/en/stable/_modules/pika/connection.html#URLParameters
-    """
-
-    uri_pattern = r"^(?:([^:\/?#\s]+):\/{2})?(?:([^@\/?#\s]+)@)?([^\/?#\s]+)?(?:\/([^?#\s]*))?(?:[?]([^#\s]+))?\S*$"
-    match = re_match(pattern=uri_pattern, string=rabbitmq_address)
-    if not match:
-        raise ValueError(f"The message queue server address is in wrong format: '{rabbitmq_address}'")
-    url_params = URLParameters(rabbitmq_address)
-
-    parsed_data = {
-        'username': url_params.credentials.username,
-        'password': url_params.credentials.password,
-        'host': url_params.host,
-        'port': url_params.port,
-        'vhost': url_params.virtual_host
-    }
-    return parsed_data
+def is_url_responsive(url: str, tries: int = 1, wait_time: int = 3) -> bool:
+    while tries > 0:
+        try:
+            if requests_get(url).status_code == 200:
+                return True
+        except Exception:
+            pass
+        sleep(wait_time)
+        tries -= 1
+    return False
 
 
 def download_ocrd_all_tool_json(ocrd_all_url: str):
     if not ocrd_all_url:
-        raise ValueError(f'The URL of ocrd all tool json is empty')
-    headers = {'Accept': 'application/json'}
+        raise ValueError(f"The URL of ocrd all tool json is empty")
+    headers = {"Accept": "application/json"}
     response = Session_TCP().get(ocrd_all_url, headers=headers)
     if not response.status_code == 200:
         raise ValueError(f"Failed to download ocrd all tool json from: '{ocrd_all_url}'")
@@ -137,43 +125,24 @@ def get_ocrd_workspace_physical_pages(mets_path: str, mets_server_url: str = Non
 
 
 def is_mets_server_running(mets_server_url: str) -> bool:
-    protocol = 'tcp' if (mets_server_url.startswith('http://') or mets_server_url.startswith('https://')) else 'uds'
-    session = Session_TCP() if protocol == 'tcp' else Session_UDS()
-    mets_server_url = mets_server_url if protocol == 'tcp' else f'http+unix://{mets_server_url.replace("/", "%2F")}'
+    protocol = "tcp" if (mets_server_url.startswith("http://") or mets_server_url.startswith("https://")) else "uds"
+    session = Session_TCP() if protocol == "tcp" else
Session_UDS() + if protocol == "uds": + mets_server_url = convert_url_to_uds_format(mets_server_url) try: - response = session.get(url=f'{mets_server_url}/workspace_path') + response = session.get(url=f"{mets_server_url}/workspace_path") except Exception: return False - if response.status_code == 200: - return True - return False + return response.status_code == 200 def stop_mets_server(mets_server_url: str) -> bool: - protocol = 'tcp' if (mets_server_url.startswith('http://') or mets_server_url.startswith('https://')) else 'uds' - session = Session_TCP() if protocol == 'tcp' else Session_UDS() - mets_server_url = mets_server_url if protocol == 'tcp' else f'http+unix://{mets_server_url.replace("/", "%2F")}' + protocol = "tcp" if (mets_server_url.startswith("http://") or mets_server_url.startswith("https://")) else "uds" + session = Session_TCP() if protocol == "tcp" else Session_UDS() + if protocol == "uds": + mets_server_url = convert_url_to_uds_format(mets_server_url) try: - response = session.delete(url=f'{mets_server_url}/') + response = session.delete(url=f"{mets_server_url}/") except Exception: return False - if response.status_code == 200: - return True - return False - - -def validate_workflow(workflow: str, logger=None) -> bool: - """ Check that workflow is not empty and parseable to a lists of ProcessorTask - """ - if not workflow.strip(): - if logger: - logger.info("Workflow is invalid (empty string)") - return False - try: - tasks_list = workflow.splitlines() - [ProcessorTask.parse(task_str) for task_str in tasks_list if task_str.strip()] - except ValueError as e: - if logger: - logger.info(f"Workflow is invalid, parsing to ProcessorTasks failed: {e}") - return False - return True + return response.status_code == 200 diff --git a/src/ocrd_utils/config.py b/src/ocrd_utils/config.py index cc12a3115..08b9b77a6 100644 --- a/src/ocrd_utils/config.py +++ b/src/ocrd_utils/config.py @@ -153,8 +153,8 @@ def _ocrd_download_timeout_parser(val): description="Default address of Workspace Server to connect to (for `ocrd network client workspace`).", default=(True, '')) -config.add("OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS", - description="Number of attempts for a worker to create its queue. 
Helpfull if the rabbitmq-server needs time to be fully started", +config.add("OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS", + description="Number of attempts for a RabbitMQ client to connect before failing.", parser=int, default=(True, 3)) diff --git a/src/ocrd_validators/message_processing.schema.yml b/src/ocrd_validators/message_processing.schema.yml index b4363aeba..d2a9432c5 100644 --- a/src/ocrd_validators/message_processing.schema.yml +++ b/src/ocrd_validators/message_processing.schema.yml @@ -52,6 +52,12 @@ properties: parameters: description: Parameters for the used model type: object + agent_type: + description: The network agent type - worker or server + type: string + enum: + - worker + - server result_queue_name: description: Name of the queue to which result is published type: string diff --git a/src/ocrd_validators/message_result.schema.yml b/src/ocrd_validators/message_result.schema.yml index d2c87ba6e..90ca3c856 100644 --- a/src/ocrd_validators/message_result.schema.yml +++ b/src/ocrd_validators/message_result.schema.yml @@ -5,26 +5,25 @@ type: object additionalProperties: false required: - job_id - - status -oneOf: - - required: - - path_to_mets - - required: - - workspace_id + - state + - path_to_mets + - workspace_id properties: job_id: description: The ID of the job type: string format: uuid - status: - description: The current status of the job + state: + description: The current state of the job type: string enum: - CACHED + - CANCELLED - QUEUED - RUNNING - SUCCESS - FAILED + - UNSET path_to_mets: description: Path to a METS file type: string diff --git a/src/ocrd_validators/ocrd_network_message_validator.py b/src/ocrd_validators/ocrd_network_message_validator.py index 486efea43..efba2262c 100644 --- a/src/ocrd_validators/ocrd_network_message_validator.py +++ b/src/ocrd_validators/ocrd_network_message_validator.py @@ -1,10 +1,7 @@ """ Validating ocrd-network messages """ -from .constants import ( - MESSAGE_SCHEMA_PROCESSING, - MESSAGE_SCHEMA_RESULT -) +from .constants import MESSAGE_SCHEMA_PROCESSING, MESSAGE_SCHEMA_RESULT from .json_validator import JsonValidator diff --git a/tests/cli/test_bashlib.py b/tests/cli/test_bashlib.py index 74a623d1b..ab52b6b1b 100644 --- a/tests/cli/test_bashlib.py +++ b/tests/cli/test_bashlib.py @@ -1,6 +1,8 @@ +from contextlib import contextmanager from tests.base import CapturingTestCase as TestCase, main, assets, copy_of_directory import os, sys +from os import environ import traceback import subprocess import tempfile @@ -110,51 +112,38 @@ def test_bashlib_minversion(self): assert "ERROR: ocrd/core is too old" in err def test_bashlib_cp_processor(self): - tool = { - "version": "1.0", - "tools": { - "ocrd-cp": { - "executable": "ocrd-cp", - "description": "dummy processor copying", - "steps": ["preprocessing/optimization"], - "categories": ["Image preprocessing"], - "parameters": { - "message": { - "type": "string", - "default": "", - "description": "message to print on stdout" - } - } - } - } - } - script = (Path(__file__).parent.parent / 'data/bashlib_cp_processor.sh').read_text() - with copy_of_directory(assets.path_to('kant_aufklaerung_1784/data')) as wsdir: - with pushd_popd(wsdir): - with open('ocrd-tool.json', 'w') as toolfile: - json.dump(tool, toolfile) - # run on 1 input - exit_code, out, err = self.invoke_bash( - script, '-I', 'OCR-D-GT-PAGE', '-O', 'OCR-D-GT-PAGE2', '-P', 'message', 'hello world', - executable='ocrd-cp') - print({'exit_code': exit_code, 'out': out, 'err': err}) - assert 'single input fileGrp' in err - 
assert 'processing PAGE-XML' in err - assert exit_code == 0 - assert 'hello world' in out - path = pathlib.Path('OCR-D-GT-PAGE2') - assert path.is_dir() - assert next(path.glob('*.xml'), None) - # run on 2 inputs - exit_code, out, err = self.invoke_bash( - script, '-I', 'OCR-D-IMG,OCR-D-GT-PAGE', '-O', 'OCR-D-IMG2', - executable='ocrd-cp') - assert 'multiple input fileGrps' in err - assert exit_code == 0 - assert 'ignoring application/vnd.prima.page+xml' in err - path = pathlib.Path('OCR-D-IMG2') - assert path.is_dir() - assert next(path.glob('*.tif'), None) + # script = (Path(__file__).parent.parent / 'data/bashlib_cp_processor.sh').read_text() + # ocrd_tool = json.loads((Path(__file__).parent.parent / 'data/bashlib_cp_processor.ocrd-tool.json').read_text()) + scriptdir = Path(__file__).parent.parent / 'data' + + with copy_of_directory(assets.path_to('kant_aufklaerung_1784/data')) as wsdir, pushd_popd(wsdir): + with open(f'{scriptdir}/ocrd-cp', 'r', encoding='utf-8') as script_f: + script = script_f.read() + with open(f'{scriptdir}/ocrd-cp.ocrd-tool.json', 'r', encoding='utf-8') as tool_in, \ + open(f'{wsdir}/ocrd-tool.json', 'w', encoding='utf-8') as tool_out: + tool_out.write(tool_in.read()) + # run on 1 input + exit_code, out, err = self.invoke_bash( + script, '-I', 'OCR-D-GT-PAGE', '-O', 'OCR-D-GT-PAGE2', '-P', 'message', 'hello world', + executable='ocrd-cp') + print({'exit_code': exit_code, 'out': out, 'err': err}) + assert 'single input fileGrp' in err + assert 'processing PAGE-XML' in err + assert exit_code == 0 + assert 'hello world' in out + path = pathlib.Path('OCR-D-GT-PAGE2') + assert path.is_dir() + assert next(path.glob('*.xml'), None) + # run on 2 inputs + exit_code, out, err = self.invoke_bash( + script, '-I', 'OCR-D-IMG,OCR-D-GT-PAGE', '-O', 'OCR-D-IMG2', + executable='ocrd-cp') + assert 'multiple input fileGrps' in err + assert exit_code == 0 + assert 'ignoring application/vnd.prima.page+xml' in err + path = pathlib.Path('OCR-D-IMG2') + assert path.is_dir() + assert next(path.glob('*.tif'), None) if __name__ == "__main__": main(__file__) diff --git a/tests/conftest.py b/tests/conftest.py index 1f3507107..f7e94b743 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ pytest_plugins = [ "tests.network.fixtures_mongodb", + "tests.network.fixtures_processing_requests", "tests.network.fixtures_rabbitmq" ] diff --git a/tests/data/bashlib_cp_processor.sh b/tests/data/ocrd-cp old mode 100644 new mode 100755 similarity index 91% rename from tests/data/bashlib_cp_processor.sh rename to tests/data/ocrd-cp index 7febaa769..4f652d1a1 --- a/tests/data/bashlib_cp_processor.sh +++ b/tests/data/ocrd-cp @@ -4,7 +4,14 @@ set -eu set -o pipefail MIMETYPE_PAGE=$(ocrd bashlib constants MIMETYPE_PAGE) source $(ocrd bashlib filename) -ocrd__wrap ocrd-tool.json ocrd-cp "$@" +set -x + +_ocrd_tool_json="$0.ocrd-tool.json" +if [[ $_ocrd_tool_json == $0 || ! 
-e $_ocrd_tool_json ]];then + _ocrd_tool_json='ocrd-tool.json' +fi + +ocrd__wrap $_ocrd_tool_json ocrd-cp "$@" IFS=',' read -ra in_file_grps <<< ${ocrd__argv[input_file_grp]} if ((${#in_file_grps[*]}>1)); then diff --git a/tests/data/ocrd-cp.ocrd-tool.json b/tests/data/ocrd-cp.ocrd-tool.json new file mode 100755 index 000000000..728c144c5 --- /dev/null +++ b/tests/data/ocrd-cp.ocrd-tool.json @@ -0,0 +1,18 @@ +{ + "version": "1.0", + "tools": { + "ocrd-cp": { + "executable": "ocrd-cp", + "description": "dummy processor copying", + "steps": ["preprocessing/optimization"], + "categories": ["Image preprocessing"], + "parameters": { + "message": { + "type": "string", + "default": "", + "description": "message to print on stdout" + } + } + } + } +} diff --git a/tests/model/test_ocrd_mets.py b/tests/model/test_ocrd_mets.py index 4fbc38ed1..cb8d10de6 100644 --- a/tests/model/test_ocrd_mets.py +++ b/tests/model/test_ocrd_mets.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - from datetime import datetime from os.path import join diff --git a/tests/network/config.py b/tests/network/config.py index 646833aee..67c4ff24b 100644 --- a/tests/network/config.py +++ b/tests/network/config.py @@ -1,7 +1,7 @@ from pathlib import Path from tempfile import gettempdir -from src.ocrd_utils.config import OcrdEnvConfig -from src.ocrd_utils.config import _ocrd_download_timeout_parser +from ocrd_utils.config import OcrdEnvConfig +from ocrd_utils.config import _ocrd_download_timeout_parser test_config = OcrdEnvConfig() @@ -74,12 +74,12 @@ ) test_config.add( - name="OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS", + name="OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS", description=""" - Number of attempts for a worker to create its queue. Helpful if the rabbitmq-server needs time to be fully started + Number of attempts for a RabbitMQ client to connect before failing """, parser=int, - default=(True, 1) + default=(True, 3) ) test_config.add( diff --git a/tests/network/fixtures_mongodb.py b/tests/network/fixtures_mongodb.py index 1409bafe7..4b829e5e7 100644 --- a/tests/network/fixtures_mongodb.py +++ b/tests/network/fixtures_mongodb.py @@ -1,6 +1,5 @@ from pytest import fixture -from src.ocrd_network.database import sync_initiate_database -from src.ocrd_network.utils import verify_database_uri +from src.ocrd_network.database import sync_initiate_database, verify_database_uri from tests.network.config import test_config diff --git a/tests/network/fixtures_processing_requests.py b/tests/network/fixtures_processing_requests.py new file mode 100644 index 000000000..55ee0cc9e --- /dev/null +++ b/tests/network/fixtures_processing_requests.py @@ -0,0 +1,17 @@ +from pytest import fixture +from src.ocrd_network.constants import AgentType +from src.ocrd_network.models import PYJobInput + + +@fixture(scope="package", name="processing_request_1") +def fixture_processing_request_1() -> PYJobInput: + workspace_key = "/path/to/mets.xml" + processing_request1 = PYJobInput( + path_to_mets=workspace_key, + input_file_grps=["DEFAULT"], + output_file_grps=["OCR-D-BIN"], + agent_type=AgentType.PROCESSING_WORKER, + page_id="PHYS_0001..PHYS_0003", + parameters={} + ) + yield processing_request1 diff --git a/tests/network/fixtures_rabbitmq.py b/tests/network/fixtures_rabbitmq.py index a3b1300cf..29c5913d3 100644 --- a/tests/network/fixtures_rabbitmq.py +++ b/tests/network/fixtures_rabbitmq.py @@ -1,7 +1,12 @@ +from logging import getLogger from pika.credentials import PlainCredentials from pytest import fixture -from src.ocrd_network.rabbitmq_utils 
import RMQConnector, RMQConsumer, RMQPublisher -from src.ocrd_network.utils import verify_and_parse_mq_uri +from src.ocrd_network.rabbitmq_utils import ( + connect_rabbitmq_consumer, + connect_rabbitmq_publisher, + RMQConnector, + verify_and_parse_mq_uri +) from tests.network.config import test_config @@ -30,7 +35,7 @@ def fixture_rabbitmq_defaults(): RMQConnector.exchange_declare( channel=test_channel, exchange_name=DEFAULT_EXCHANGER_NAME, - exchange_type='direct', + exchange_type="direct", durable=False ) RMQConnector.queue_declare(channel=test_channel, queue_name=DEFAULT_QUEUE, durable=False) @@ -47,29 +52,12 @@ def fixture_rabbitmq_defaults(): @fixture(scope="package", name="rabbitmq_publisher") def fixture_rabbitmq_publisher(rabbitmq_defaults): rmq_data = verify_and_parse_mq_uri(RABBITMQ_URL) - rmq_publisher = RMQPublisher( - host=rmq_data["host"], - port=rmq_data["port"], - vhost=rmq_data["vhost"] - ) - rmq_publisher.authenticate_and_connect( - username=rmq_data["username"], - password=rmq_data["password"] - ) - rmq_publisher.enable_delivery_confirmations() - yield rmq_publisher + logger = getLogger(name="ocrd_network_testing") + yield connect_rabbitmq_publisher(logger=logger, rmq_data=rmq_data) @fixture(scope="package", name="rabbitmq_consumer") def fixture_rabbitmq_consumer(rabbitmq_defaults): rmq_data = verify_and_parse_mq_uri(RABBITMQ_URL) - rmq_consumer = RMQConsumer( - host=rmq_data["host"], - port=rmq_data["port"], - vhost=rmq_data["vhost"] - ) - rmq_consumer.authenticate_and_connect( - username=rmq_data["username"], - password=rmq_data["password"] - ) - yield rmq_consumer + logger = getLogger(name="ocrd_network_testing") + yield connect_rabbitmq_consumer(logger=logger, rmq_data=rmq_data) diff --git a/tests/network/test_db.py b/tests/network/test_integration_1_db.py similarity index 65% rename from tests/network/test_db.py rename to tests/network/test_integration_1_db.py index 6a6982288..c46ac7571 100644 --- a/tests/network/test_db.py +++ b/tests/network/test_integration_1_db.py @@ -3,7 +3,8 @@ from pathlib import Path from pytest import raises from tests.base import assets -from src.ocrd_network.models import DBProcessorJob, DBWorkflowScript, StateEnum +from src.ocrd_network import JobState +from src.ocrd_network.models import DBProcessorJob, DBWorkflowScript from src.ocrd_network.database import ( sync_db_create_processing_job, sync_db_get_processing_job, @@ -18,61 +19,70 @@ def test_db_processing_job_create(mongo_client): - job_id = f'test_job_id_{datetime.now()}' + job_id = f"test_job_id_{datetime.now()}" + path_to_mets = "/ocrd/dummy/path" + processor_name = "ocrd-dummy" + job_state = JobState.cached + input_file_group = "DEFAULT" + output_file_group = "OCR-D-DUMMY" db_created_processing_job = sync_db_create_processing_job( db_processing_job=DBProcessorJob( job_id=job_id, - processor_name='ocrd-dummy', - state=StateEnum.cached, - path_to_mets='/ocrd/dummy/path', - input_file_grps=['DEFAULT'], - output_file_grps=['OCR-D-DUMMY'] + processor_name=processor_name, + state=job_state, + path_to_mets=path_to_mets, + input_file_grps=[input_file_group], + output_file_grps=[output_file_group] ) ) assert db_created_processing_job db_found_processing_job = sync_db_get_processing_job(job_id=job_id) assert db_found_processing_job assert db_found_processing_job.job_id == job_id - assert db_found_processing_job.processor_name == 'ocrd-dummy' - assert db_found_processing_job.state == StateEnum.cached - assert db_found_processing_job.path_to_mets == '/ocrd/dummy/path' - assert 
db_found_processing_job.input_file_grps == ['DEFAULT'] - assert db_found_processing_job.output_file_grps == ['OCR-D-DUMMY'] + assert db_found_processing_job.processor_name == processor_name + assert db_found_processing_job.state == job_state + assert db_found_processing_job.path_to_mets == path_to_mets + assert db_found_processing_job.input_file_grps == [input_file_group] + assert db_found_processing_job.output_file_grps == [output_file_group] with raises(ValueError): - sync_db_get_processing_job(job_id='non-existing-id') + sync_db_get_processing_job(job_id="non-existing-id") def test_db_processing_job_update(mongo_client): - job_id = f'test_job_id_{datetime.now()}' + job_id = f"test_job_id_{datetime.now()}" + path_to_mets = "/ocrd/dummy/path" + processor_name = "ocrd-dummy" + input_file_group = "DEFAULT" + output_file_group = "OCR-D-DUMMY" db_created_processing_job = sync_db_create_processing_job( db_processing_job=DBProcessorJob( job_id=job_id, - processor_name='ocrd-dummy', - state=StateEnum.cached, - path_to_mets='/ocrd/dummy/path', - input_file_grps=['DEFAULT'], - output_file_grps=['OCR-D-DUMMY'] + processor_name=processor_name, + state=JobState.cached, + path_to_mets=path_to_mets, + input_file_grps=[input_file_group], + output_file_grps=[output_file_group] ) ) assert db_created_processing_job db_found_processing_job = sync_db_get_processing_job(job_id=job_id) assert db_found_processing_job - db_updated_processing_job = sync_db_update_processing_job(job_id=job_id, state=StateEnum.running) + db_updated_processing_job = sync_db_update_processing_job(job_id=job_id, state=JobState.running) assert db_found_processing_job != db_updated_processing_job db_found_updated_processing_job = sync_db_get_processing_job(job_id=job_id) assert db_found_updated_processing_job assert db_found_updated_processing_job == db_updated_processing_job - assert db_found_updated_processing_job.state == StateEnum.running + assert db_found_updated_processing_job.state == JobState.running with raises(ValueError): - sync_db_update_processing_job(job_id='non-existing', state=StateEnum.running) - sync_db_update_processing_job(job_id=job_id, non_existing_field='dummy_value') - sync_db_update_processing_job(job_id=job_id, processor_name='non-updatable-field') + sync_db_update_processing_job(job_id="non-existing", state=JobState.running) + sync_db_update_processing_job(job_id=job_id, non_existing_field="dummy_value") + sync_db_update_processing_job(job_id=job_id, processor_name="non-updatable-field") def test_db_workspace_create(mongo_client): - mets_path = assets.path_to('kant_aufklaerung_1784/data/mets.xml') + mets_path = assets.path_to("kant_aufklaerung_1784/data/mets.xml") db_created_workspace = sync_db_create_workspace(mets_path=mets_path) assert db_created_workspace assert db_created_workspace.workspace_mets_path == mets_path @@ -81,16 +91,16 @@ def test_db_workspace_create(mongo_client): assert db_found_workspace == db_created_workspace with raises(ValueError): - sync_db_get_workspace(workspace_id='non-existing-id') - sync_db_get_workspace(workspace_mets_path='non-existing-mets') + sync_db_get_workspace(workspace_id="non-existing-id") + sync_db_get_workspace(workspace_mets_path="non-existing-mets") with raises(FileNotFoundError): - sync_db_create_workspace(mets_path='non-existing-mets') + sync_db_create_workspace(mets_path="non-existing-mets") def test_db_workspace_update(mongo_client): - mets_path = assets.path_to('kant_aufklaerung_1784-binarized/data/mets.xml') - dummy_mets_server_url = '/tmp/dummy.sock' + 
mets_path = assets.path_to("kant_aufklaerung_1784-binarized/data/mets.xml") + dummy_mets_server_url = "/tmp/dummy.sock" db_created_workspace = sync_db_create_workspace(mets_path=mets_path) assert db_created_workspace @@ -115,41 +125,35 @@ def test_db_workspace_update(mongo_client): # TODO: There is no db wrapper implemented due to direct access in the processing server... # TODO2: Should be refactored with proper asset access def create_db_model_workflow_script( - workflow_id: str, - script_path: Path = Path(Path(__file__).parent, "dummy-workflow.txt") + workflow_id: str, + script_path: Path = Path(Path(__file__).parent, "dummy-workflow.txt") ) -> DBWorkflowScript: workflow_id = workflow_id - with open(script_path, 'rb') as fp: + with open(script_path, "rb") as fp: content = (fp.read()).decode("utf-8") content_hash = md5(content.encode("utf-8")).hexdigest() return DBWorkflowScript(workflow_id=workflow_id, content=content, content_hash=content_hash) def test_db_workflow_script_create(mongo_client): - workflow_id = f'test_workflow_{datetime.now()}' + workflow_id = f"test_workflow_{datetime.now()}" db_model_workflow_script = create_db_model_workflow_script(workflow_id=workflow_id) - db_created_workflow_script = sync_db_create_workflow_script( - db_workflow_script=db_model_workflow_script - ) + db_created_workflow_script = sync_db_create_workflow_script(db_workflow_script=db_model_workflow_script) assert db_created_workflow_script db_found_workflow_script = sync_db_get_workflow_script(workflow_id=workflow_id) assert db_found_workflow_script assert db_found_workflow_script == db_created_workflow_script with raises(ValueError): - sync_db_get_workflow_script(workflow_id='non-existing-id') + sync_db_get_workflow_script(workflow_id="non-existing-id") def test_db_find_workflow_script_by_content(mongo_client): - workflow_id = f'test_workflow_{datetime.now()}' - db_model_workflow_script = create_db_model_workflow_script(workflow_id=workflow_id) - db_created_workflow_script = sync_db_create_workflow_script( - db_workflow_script=db_model_workflow_script - ) + workflow_id = f"test_workflow_{datetime.now()}" + db_model_wf_script = create_db_model_workflow_script(workflow_id=workflow_id) + db_created_workflow_script = sync_db_create_workflow_script(db_workflow_script=db_model_wf_script) assert db_created_workflow_script - db_found_workflow_script = sync_db_find_first_workflow_script_by_content( - workflow_id=db_model_workflow_script.workflow_id - ) + db_found_workflow_script = sync_db_find_first_workflow_script_by_content(workflow_id=db_model_wf_script.workflow_id) assert db_found_workflow_script assert db_found_workflow_script == db_created_workflow_script diff --git a/tests/network/test_rabbitmq.py b/tests/network/test_integration_2_rabbitmq.py similarity index 55% rename from tests/network/test_rabbitmq.py rename to tests/network/test_integration_2_rabbitmq.py index 951266e5d..4a5239fee 100644 --- a/tests/network/test_rabbitmq.py +++ b/tests/network/test_integration_2_rabbitmq.py @@ -9,71 +9,55 @@ def test_rmq_publish_then_consume_2_messages(rabbitmq_publisher, rabbitmq_consumer): test_headers = {"Test Header": "Test Value"} test_properties = BasicProperties( - app_id='webapi-processing-broker', - content_type='application/json', + app_id="webapi-processing-broker", + content_type="application/json", headers=test_headers ) + message1 = "RabbitMQ test 123" + message2 = "RabbitMQ test 456" rabbitmq_publisher.publish_to_queue( - queue_name=DEFAULT_QUEUE, - message="RabbitMQ test 123", - 
exchange_name=DEFAULT_EXCHANGER_NAME, - properties=test_properties + queue_name=DEFAULT_QUEUE, message=message1, exchange_name=DEFAULT_EXCHANGER_NAME, properties=test_properties ) rabbitmq_publisher.publish_to_queue( - queue_name=DEFAULT_QUEUE, - message="RabbitMQ test 456", - exchange_name=DEFAULT_EXCHANGER_NAME, - properties=test_properties + queue_name=DEFAULT_QUEUE, message=message2, exchange_name=DEFAULT_EXCHANGER_NAME, properties=test_properties ) assert rabbitmq_publisher.message_counter == 2 # Consume the 1st message - method_frame, header_frame, message = rabbitmq_consumer.get_one_message( - queue_name=DEFAULT_QUEUE, - auto_ack=True - ) - assert method_frame.delivery_tag == 1 # 1st delivered message to this queue + method_frame, header_frame, message = rabbitmq_consumer.get_one_message(queue_name=DEFAULT_QUEUE, auto_ack=True) assert method_frame.message_count == 1 # messages left in the queue assert method_frame.redelivered is False assert method_frame.exchange == DEFAULT_EXCHANGER_NAME assert method_frame.routing_key == DEFAULT_QUEUE # It's possible to assert header_frame the same way - assert message.decode() == "RabbitMQ test 123" + assert message.decode() == message1 # Consume the 2nd message - method_frame, header_frame, message = rabbitmq_consumer.get_one_message( - queue_name=DEFAULT_QUEUE, - auto_ack=True - ) - assert method_frame.delivery_tag == 2 # 2nd delivered message to this queue + method_frame, header_frame, message = rabbitmq_consumer.get_one_message(queue_name=DEFAULT_QUEUE, auto_ack=True) assert method_frame.message_count == 0 # messages left in the queue assert method_frame.redelivered is False assert method_frame.exchange == DEFAULT_EXCHANGER_NAME assert method_frame.routing_key == DEFAULT_QUEUE # It's possible to assert header_frame the same way - assert message.decode() == "RabbitMQ test 456" + assert message.decode() == message2 def test_rmq_publish_then_consume_ocrd_message(rabbitmq_publisher, rabbitmq_consumer): - ocrd_processing_message = { - "job_id": "Test_job_id", - "workflow_id": "Test_workflow_id", - "workspace_id": "Test_workspace_id" - } + test_job_id = "test_job_id" + test_wf_id = "test_workflow_id" + test_ws_id = "test_ws_id" + ocrd_processing_message = {"job_id": test_job_id, "workflow_id": test_wf_id, "workspace_id": test_ws_id} message_bytes = dumps(ocrd_processing_message) rabbitmq_publisher.publish_to_queue( - queue_name=DEFAULT_QUEUE, - message=message_bytes, - exchange_name=DEFAULT_EXCHANGER_NAME, - properties=None + queue_name=DEFAULT_QUEUE, message=message_bytes, exchange_name=DEFAULT_EXCHANGER_NAME, properties=None ) - method_frame, header_frame, message = rabbitmq_consumer.get_one_message( - queue_name=DEFAULT_QUEUE, - auto_ack=True - ) + method_frame, header_frame, message = rabbitmq_consumer.get_one_message(queue_name=DEFAULT_QUEUE, auto_ack=True) assert method_frame.message_count == 0 # messages left in the queue + assert method_frame.redelivered is False + assert method_frame.exchange == DEFAULT_EXCHANGER_NAME + assert method_frame.routing_key == DEFAULT_QUEUE decoded_message = loads(message) - assert decoded_message["job_id"] == "Test_job_id" - assert decoded_message["workflow_id"] == "Test_workflow_id" - assert decoded_message["workspace_id"] == "Test_workspace_id" + assert decoded_message["job_id"] == test_job_id + assert decoded_message["workflow_id"] == test_wf_id + assert decoded_message["workspace_id"] == test_ws_id diff --git a/tests/network/test_integration_3_server_cache_requests.py 
b/tests/network/test_integration_3_server_cache_requests.py new file mode 100644 index 000000000..a4c1b6c33 --- /dev/null +++ b/tests/network/test_integration_3_server_cache_requests.py @@ -0,0 +1,167 @@ +from typing import List +from src.ocrd_network.constants import JobState +from src.ocrd_network.database import ( + sync_db_create_processing_job, + sync_db_get_processing_job, + sync_db_update_processing_job +) +from src.ocrd_network.models import DBProcessorJob, PYJobInput +from src.ocrd_network.server_cache import CacheProcessingRequests +from src.ocrd_network.utils import generate_id + + +def test_update_request_counter(): + requests_cache = CacheProcessingRequests() + workspace_key = "/path/to/mets.xml" + requests_cache.update_request_counter(workspace_key=workspace_key, by_value=0) + assert requests_cache.processing_counter[workspace_key] == 0 + requests_cache.update_request_counter(workspace_key=workspace_key, by_value=3) + assert requests_cache.processing_counter[workspace_key] == 3 + requests_cache.update_request_counter(workspace_key=workspace_key, by_value=-1) + requests_cache.update_request_counter(workspace_key=workspace_key, by_value=-1) + requests_cache.update_request_counter(workspace_key=workspace_key, by_value=-1) + assert requests_cache.processing_counter[workspace_key] == 0 + + +def test_cache_request(processing_request_1: PYJobInput): + requests_cache = CacheProcessingRequests() + workspace_key = "/path/to/mets.xml" + requests_cache.cache_request(workspace_key=workspace_key, data=processing_request_1) + requests_cache.cache_request(workspace_key=workspace_key, data=processing_request_1) + # two cached requests for the workspace key entry + assert len(requests_cache.processing_requests[workspace_key]) == 2 + # one workspace key entry in the processing requests cache + assert len(requests_cache.processing_requests) == 1 + + +def test_has_workspace_cached_requests(processing_request_1: PYJobInput): + requests_cache = CacheProcessingRequests() + workspace_key = "/path/to/mets.xml" + processing_request_1.path_to_mets = workspace_key + assert not requests_cache.has_workspace_cached_requests(workspace_key=workspace_key) + requests_cache.cache_request(workspace_key=workspace_key, data=processing_request_1) + assert requests_cache.has_workspace_cached_requests(workspace_key=workspace_key) + assert not requests_cache.has_workspace_cached_requests(workspace_key="non-existing") + + +def create_processing_requests_list(workspace_key: str) -> List[PYJobInput]: + processing_request_1 = PYJobInput( + processor_name="processor_name_1", + path_to_mets=workspace_key, + input_file_grps=["DEFAULT"], + output_file_grps=["OCR-D-BIN"], + page_id="PHYS_0001..PHYS_0003", + job_id=generate_id(), + depends_on=[] + ) + processing_request_2 = PYJobInput( + processor_name="processor_name_2", + path_to_mets=workspace_key, + input_file_grps=["OCR-D-BIN"], + output_file_grps=["OCR-D-CROP"], + page_id="PHYS_0001..PHYS_0003", + job_id=generate_id(), + depends_on=[processing_request_1.job_id] + ) + processing_request_3 = PYJobInput( + processor_name="processor_name_3", + path_to_mets=workspace_key, + input_file_grps=["OCR-D-CROP"], + output_file_grps=["OCR-D-BIN2"], + page_id="PHYS_0001..PHYS_0003", + job_id=generate_id(), + depends_on=[processing_request_2.job_id] + ) + processing_request_4 = PYJobInput( + processor_name="processor_name_4", + path_to_mets=workspace_key, + input_file_grps=["OCR-D-BIN2"], + output_file_grps=["OCR-D-BIN-DENOISE"], + page_id="PHYS_0001..PHYS_0003", + job_id=generate_id(), + 
depends_on=[processing_request_2.job_id] + ) + + processing_requests = [processing_request_1, processing_request_2, processing_request_3, processing_request_4] + return processing_requests + + +def create_processing_jobs_db_entries( + requests_cache: CacheProcessingRequests, workspace_key: str +) -> List[DBProcessorJob]: + processing_requests_list = create_processing_requests_list(workspace_key=workspace_key) + jobs_list = [] + # Insert processing jobs into the database based on processing requests + for processing_request in processing_requests_list: + requests_cache.cache_request(workspace_key=workspace_key, data=processing_request) + db_processing_job = DBProcessorJob( + **processing_request.dict(exclude_unset=True, exclude_none=True), + state=JobState.cached + ) + sync_db_create_processing_job(db_processing_job) + assert db_processing_job.state == JobState.cached + jobs_list.append(db_processing_job) + return jobs_list + + +def test_is_caching_required(): + requests_cache = CacheProcessingRequests() + workspace_key = "/path/to/mets.xml" + jobs_list = create_processing_jobs_db_entries(requests_cache=requests_cache, workspace_key=workspace_key) + + # depends on nothing, should not be cached + assert not requests_cache.sync_is_caching_required(job_dependencies=jobs_list[0].depends_on) + # depends on processing_job_1, should be cached + assert requests_cache.sync_is_caching_required(job_dependencies=jobs_list[1].depends_on) + # depends on processing_job_2, should be cached + assert requests_cache.sync_is_caching_required(job_dependencies=jobs_list[2].depends_on) + # depends on processing_job_2, should be cached + assert requests_cache.sync_is_caching_required(job_dependencies=jobs_list[3].depends_on) + + sync_db_update_processing_job(jobs_list[0].job_id, state=JobState.success) + # the dependent job has successfully finished, no caching required + assert not requests_cache.sync_is_caching_required(job_dependencies=jobs_list[1].depends_on) + sync_db_update_processing_job(jobs_list[1].job_id, state=JobState.success) + # the dependent job has successfully finished, no caching required for job 3 and job 4 + assert not requests_cache.sync_is_caching_required(job_dependencies=jobs_list[2].depends_on) + assert not requests_cache.sync_is_caching_required(job_dependencies=jobs_list[3].depends_on) + + +def test_cancel_dependent_jobs(): + requests_cache = CacheProcessingRequests() + # Must match with the workspace_key in the processing_jobs_list + workspace_key = "/path/to/mets.xml" + jobs_list = create_processing_jobs_db_entries(requests_cache=requests_cache, workspace_key=workspace_key) + + db_processing_job_1 = sync_db_update_processing_job(jobs_list[0].job_id, state=JobState.failed) + assert db_processing_job_1.state == JobState.failed + requests_cache.sync_cancel_dependent_jobs(workspace_key=workspace_key, processing_job_id=jobs_list[0].job_id) + db_processing_job_2 = sync_db_get_processing_job(job_id=jobs_list[1].job_id) + db_processing_job_3 = sync_db_get_processing_job(job_id=jobs_list[2].job_id) + db_processing_job_4 = sync_db_get_processing_job(job_id=jobs_list[3].job_id) + # job 2 is cancelled because job 1 has failed + assert db_processing_job_2.state == JobState.cancelled + # job 3 and job 4 are cancelled because job 2 got cancelled + assert db_processing_job_3.state == JobState.cancelled + assert db_processing_job_4.state == JobState.cancelled + + +def test_consume_cached_requests(): + requests_cache = CacheProcessingRequests() + # Must match with the workspace_key in the 
processing_jobs_list + workspace_key = "/path/to/mets.xml" + jobs_list = create_processing_jobs_db_entries(requests_cache=requests_cache, workspace_key=workspace_key) + + db_processing_job_1 = sync_db_update_processing_job(jobs_list[0].job_id, state=JobState.success) + assert db_processing_job_1.state == JobState.success + # Remove the job 1 since it's no longer cached, but manually set to success + requests_cache.processing_requests[workspace_key].pop(0) + # Consumes only processing job 2 since only that job's dependencies (i.e., job 1) have succeeded + consumed_jobs = requests_cache.sync_consume_cached_requests(workspace_key=workspace_key) + assert len(consumed_jobs) == 1 + + db_processing_job_2 = sync_db_update_processing_job(jobs_list[1].job_id, state=JobState.success) + assert db_processing_job_2.state == JobState.success + # Consumes processing job 3 and job 4 since they depend on job 2 + consumed_jobs = requests_cache.sync_consume_cached_requests(workspace_key=workspace_key) + assert len(consumed_jobs) == 2 diff --git a/tests/network/test_integration_4_processing_worker.py b/tests/network/test_integration_4_processing_worker.py new file mode 100644 index 000000000..e211bd238 --- /dev/null +++ b/tests/network/test_integration_4_processing_worker.py @@ -0,0 +1,118 @@ +from pathlib import Path +from pika import BasicProperties +from src.ocrd.processor.builtin.dummy_processor import DummyProcessor, OCRD_TOOL +from src.ocrd_network.constants import JobState +from src.ocrd_network.database import sync_db_create_workspace, sync_db_create_processing_job +from src.ocrd_network.logging_utils import get_processing_job_logging_file_path +from src.ocrd_network.models import DBProcessorJob +from src.ocrd_network.processing_worker import ProcessingWorker +from src.ocrd_network.rabbitmq_utils import OcrdProcessingMessage, OcrdResultMessage +from src.ocrd_network.utils import generate_created_time, generate_id +from tests.base import assets +from tests.network.config import test_config + + +def test_processing_worker_process_message(): + workspace_root = "kant_aufklaerung_1784/data" + path_to_mets = assets.path_to(f"{workspace_root}/mets.xml") + assert Path(path_to_mets).exists() + test_job_id = generate_id() + test_created_time = generate_created_time() + input_file_grp = "OCR-D-IMG" + output_file_grp = f"OCR-D-DUMMY-TEST-WORKER-{test_job_id}" + page_id = "PHYS_0017,PHYS_0020" + # Notice, the name is intentionally set differently from "ocrd-dummy" to prevent + # wrong reads from the deployed dummy worker (part of the processing server integration test) + processor_name = "ocrd-dummy-test" + result_queue_name = f"{processor_name}-result" + + processing_worker = ProcessingWorker( + rabbitmq_addr=test_config.RABBITMQ_URL, + mongodb_addr=test_config.DB_URL, + processor_name=processor_name, + ocrd_tool=OCRD_TOOL, + processor_class=DummyProcessor + ) + processing_worker.connect_publisher(enable_acks=True) + assert processing_worker.rmq_publisher + processing_worker.connect_consumer() + assert processing_worker.rmq_consumer + + # Create the workspace DB entry if not already existing + sync_db_create_workspace(mets_path=path_to_mets) + # Create the processing job DB entry + sync_db_create_processing_job( + db_processing_job=DBProcessorJob( + job_id=test_job_id, + processor_name=processor_name, + created_time=test_created_time, + path_to_mets=path_to_mets, + workspace_id=None, + input_file_grps=[input_file_grp], + output_file_grps=[output_file_grp], + page_id=page_id, + parameters={}, + 
result_queue_name=result_queue_name, + callback_url=None, + internal_callback_url=None + ) + ) + + # PUSH/Publish the ocrd processing message + ocrd_processing_message = OcrdProcessingMessage( + job_id=test_job_id, + processor_name=processor_name, + created_time=test_created_time, + path_to_mets=path_to_mets, + workspace_id=None, + input_file_grps=[input_file_grp], + output_file_grps=[output_file_grp], + page_id=page_id, + parameters={}, + result_queue_name=result_queue_name, + callback_url=None, + internal_callback_url=None + ) + encoded_message = OcrdProcessingMessage.encode_yml(ocrd_processing_message) + test_properties = BasicProperties( + app_id="ocrd_network_testing", + content_type="application/json", + headers={"Test Header": "Test Value"} + ) + # Push the ocrd processing message to the RabbitMQ + processing_worker.rmq_publisher.publish_to_queue( + queue_name=processor_name, message=encoded_message, properties=test_properties + ) + # The queue should have a single message inside + assert processing_worker.rmq_publisher.message_counter == 1 + + # PULL/Consume the ocrd processing message + method_frame, header_frame, processing_message = processing_worker.rmq_consumer.get_one_message( + queue_name=processor_name, auto_ack=True + ) + assert method_frame.message_count == 0 # Messages left in the queue + assert method_frame.redelivered is False + assert method_frame.routing_key == processor_name + + decoded_processing_message = OcrdProcessingMessage.decode_yml(ocrd_processing_message=processing_message) + + # Process the ocrd processing message + processing_worker.process_message(processing_message=decoded_processing_message) + + # Check the existence of the results locally + assert Path(assets.path_to(f"{workspace_root}/{output_file_grp}")).exists() + path_to_log_file = get_processing_job_logging_file_path(job_id=test_job_id) + assert Path(path_to_log_file).exists() + + # PULL/Consume the ocrd result message for verification (pushed by the process_message method) + method_frame, header_frame, result_message = processing_worker.rmq_consumer.get_one_message( + queue_name=result_queue_name, auto_ack=True + ) + assert method_frame.message_count == 0 # Messages left in the queue + assert method_frame.redelivered is False + assert method_frame.routing_key == result_queue_name + + decoded_result_message = OcrdResultMessage.decode_yml(result_message) + assert decoded_result_message.job_id == test_job_id + assert decoded_result_message.state == JobState.success + assert decoded_result_message.path_to_mets == path_to_mets diff --git a/tests/network/test_integration_5_processing_server.py b/tests/network/test_integration_5_processing_server.py new file mode 100644 index 000000000..5b22e6cc6 --- /dev/null +++ b/tests/network/test_integration_5_processing_server.py @@ -0,0 +1,107 @@ +from pathlib import Path +from requests import get as request_get, post as request_post +from time import sleep +from src.ocrd_network.constants import AgentType, JobState +from src.ocrd_network.logging_utils import get_processing_job_logging_file_path +from tests.base import assets +from tests.network.config import test_config + +PROCESSING_SERVER_URL = test_config.PROCESSING_SERVER_URL + + +def poll_till_timeout_fail_or_success(test_url: str, tries: int, wait: int) -> JobState: + job_state = JobState.unset + while tries > 0: + sleep(wait) + response = request_get(url=test_url) + assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}" + job_state = response.json()["state"] + if 
job_state == JobState.success or job_state == JobState.failed: + break + tries -= 1 + return job_state + + +def test_processing_server_connectivity(): + test_url = f"{PROCESSING_SERVER_URL}/" + response = request_get(test_url) + assert response.status_code == 200, f"Processing server is not reachable on: {test_url}, {response.status_code}" + message = response.json()["message"] + assert message.startswith("The home page of"), f"Processing server home page message is corrupted" + + +# TODO: The processing workers are still not registered when deployed separately. +# Fix that by extending the processing server. +def test_processing_server_deployed_processors(): + test_url = f"{PROCESSING_SERVER_URL}/processor" + response = request_get(test_url) + processors = response.json() + assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}" + assert processors == [], f"Mismatch in deployed processors" + + +def test_processing_server_processing_request(): + workspace_root = "kant_aufklaerung_1784/data" + path_to_mets = assets.path_to(f"{workspace_root}/mets.xml") + input_file_grp = "OCR-D-IMG" + output_file_grp = f"OCR-D-DUMMY-TEST-PS" + test_processing_job_input = { + "path_to_mets": path_to_mets, + "input_file_grps": [input_file_grp], + "output_file_grps": [output_file_grp], + "agent_type": AgentType.PROCESSING_WORKER, + "parameters": {} + } + test_processor = "ocrd-dummy" + test_url = f"{PROCESSING_SERVER_URL}/processor/run/{test_processor}" + response = request_post( + url=test_url, + headers={"accept": "application/json"}, + json=test_processing_job_input + ) + print(response.json()) + print(response.__dict__) + assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}" + processing_job_id = response.json()["job_id"] + assert processing_job_id + + job_state = poll_till_timeout_fail_or_success( + test_url=f"{PROCESSING_SERVER_URL}/processor/job/{processing_job_id}", tries=10, wait=10 + ) + assert job_state == JobState.success + + # Check the existence of the results locally + # assert Path(assets.path_to(f"{workspace_root}/{output_file_grp}")).exists() + # path_to_log_file = get_processing_job_logging_file_path(job_id=processing_job_id) + # assert Path(path_to_log_file).exists() + + +def test_processing_server_workflow_request(): + # Note: the used workflow path is volume mapped + path_to_dummy_wf = "/ocrd-data/assets/dummy-workflow.txt" + workspace_root = "kant_aufklaerung_1784/data" + path_to_mets = assets.path_to(f"{workspace_root}/mets.xml") + + # submit the workflow job + test_url = f"{PROCESSING_SERVER_URL}/workflow/run?mets_path={path_to_mets}&page_wise=True" + response = request_post( + url=test_url, + headers={"accept": "application/json"}, + files={"workflow": open(path_to_dummy_wf, 'rb')} + ) + # print(response.json()) + # print(response.__dict__) + assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}" + wf_job_id = response.json()["job_id"] + assert wf_job_id + + job_state = poll_till_timeout_fail_or_success( + test_url=f"{PROCESSING_SERVER_URL}/workflow/job-simple/{wf_job_id}", tries=30, wait=10 + ) + assert job_state == JobState.success + + # Check the existence of the results locally + # The output file groups are defined in the `path_to_dummy_wf` + # assert Path(assets.path_to(f"{workspace_root}/OCR-D-DUMMY1")).exists() + # assert Path(assets.path_to(f"{workspace_root}/OCR-D-DUMMY2")).exists() + # assert Path(assets.path_to(f"{workspace_root}/OCR-D-DUMMY3")).exists() diff 
--git a/tests/network/test_modules_logging_utils.py b/tests/network/test_modules_logging_utils.py new file mode 100644 index 000000000..530b501e0 --- /dev/null +++ b/tests/network/test_modules_logging_utils.py @@ -0,0 +1,34 @@ +from pathlib import Path +from src.ocrd_network.constants import NetworkLoggingDirs +from src.ocrd_network.logging_utils import ( + get_root_logging_dir +) +from tests.network.config import test_config + +OCRD_NETWORK_LOGS_ROOT_DIR = test_config.OCRD_NETWORK_LOGS_ROOT_DIR + + +def root_logging_dir(module_name: NetworkLoggingDirs): + func_result = get_root_logging_dir(module_name=module_name) + expected_result = Path(OCRD_NETWORK_LOGS_ROOT_DIR, module_name.value) + assert func_result == expected_result, f"Mismatch in root logging dir of module: {module_name.value}" + + +def test_root_logging_dir_mets_servers(): + root_logging_dir(module_name=NetworkLoggingDirs.METS_SERVERS) + + +def test_root_logging_dir_processor_servers(): + root_logging_dir(module_name=NetworkLoggingDirs.PROCESSOR_SERVERS) + + +def test_root_logging_dir_processing_workers(): + root_logging_dir(module_name=NetworkLoggingDirs.PROCESSING_WORKERS) + + +def test_root_logging_dir_processing_servers(): + root_logging_dir(module_name=NetworkLoggingDirs.PROCESSING_SERVERS) + + +def test_root_logging_dir_processing_jobs(): + root_logging_dir(module_name=NetworkLoggingDirs.PROCESSING_JOBS) diff --git a/tests/network/test_modules_param_validators.py b/tests/network/test_modules_param_validators.py new file mode 100644 index 000000000..2c78e92fc --- /dev/null +++ b/tests/network/test_modules_param_validators.py @@ -0,0 +1,68 @@ +from pytest import raises +from src.ocrd_network.param_validators import DatabaseParamType, ServerAddressParamType, QueueServerParamType + + +def test_database_param_type_positive(): + database_param_type = DatabaseParamType() + correct_db_uris = [ + f"mongodb://db_user:db_pass@localhost:27017/", + f"mongodb://db_user:db_pass@localhost:27017", + f"mongodb://localhost:27017/", + f"mongodb://localhost:27017", + f"mongodb://localhost" + ] + for db_uri in correct_db_uris: + database_param_type.convert(value=db_uri, param=None, ctx=None) + + +def test_database_param_type_negative(): + database_param_type = DatabaseParamType() + incorrect_db_uris = [ + f"mongodbb://db_user:db_pass@localhost:27017", + f"://db_user:db_pass@localhost:27017", + f"db_user:db_pass@localhost:27017", + f"localhost:27017", + "localhost" + ] + for db_uri in incorrect_db_uris: + with raises(Exception): + database_param_type.convert(value=db_uri, param=None, ctx=None) + + +def test_queue_server_param_type_positive(): + rmq_server_param_type = QueueServerParamType() + correct_rmq_uris = [ + f"amqp://rmq_user:rmq_pass@localhost:5672/", + f"amqp://rmq_user:rmq_pass@localhost:5672", + f"amqp://localhost:5672/", + f"amqp://localhost:5672", + f"amqp://localhost" + ] + for rmq_uri in correct_rmq_uris: + rmq_server_param_type.convert(value=rmq_uri, param=None, ctx=None) + + +def test_queue_server_param_type_negative(): + rmq_server_param_type = QueueServerParamType() + incorrect_rmq_uris = [ + f"amqpp://rmq_user:rmq_pass@localhost:5672", + f"rmq_user:rmq_pass@localhost:5672", + f"localhost:5672", + "localhost" + ] + for rmq_uri in incorrect_rmq_uris: + with raises(Exception): + rmq_server_param_type.convert(value=rmq_uri, param=None, ctx=None) + + +def test_server_address_param_type_positive(): + server_address_param_type = ServerAddressParamType() + correct_address = "localhost:8000" + 
diff --git a/tests/network/test_modules_param_validators.py b/tests/network/test_modules_param_validators.py
new file mode 100644
index 000000000..2c78e92fc
--- /dev/null
+++ b/tests/network/test_modules_param_validators.py
@@ -0,0 +1,68 @@
+from pytest import raises
+from src.ocrd_network.param_validators import DatabaseParamType, ServerAddressParamType, QueueServerParamType
+
+
+def test_database_param_type_positive():
+    database_param_type = DatabaseParamType()
+    correct_db_uris = [
+        f"mongodb://db_user:db_pass@localhost:27017/",
+        f"mongodb://db_user:db_pass@localhost:27017",
+        f"mongodb://localhost:27017/",
+        f"mongodb://localhost:27017",
+        f"mongodb://localhost"
+    ]
+    for db_uri in correct_db_uris:
+        database_param_type.convert(value=db_uri, param=None, ctx=None)
+
+
+def test_database_param_type_negative():
+    database_param_type = DatabaseParamType()
+    incorrect_db_uris = [
+        f"mongodbb://db_user:db_pass@localhost:27017",
+        f"://db_user:db_pass@localhost:27017",
+        f"db_user:db_pass@localhost:27017",
+        f"localhost:27017",
+        "localhost"
+    ]
+    for db_uri in incorrect_db_uris:
+        with raises(Exception):
+            database_param_type.convert(value=db_uri, param=None, ctx=None)
+
+
+def test_queue_server_param_type_positive():
+    rmq_server_param_type = QueueServerParamType()
+    correct_rmq_uris = [
+        f"amqp://rmq_user:rmq_pass@localhost:5672/",
+        f"amqp://rmq_user:rmq_pass@localhost:5672",
+        f"amqp://localhost:5672/",
+        f"amqp://localhost:5672",
+        f"amqp://localhost"
+    ]
+    for rmq_uri in correct_rmq_uris:
+        rmq_server_param_type.convert(value=rmq_uri, param=None, ctx=None)
+
+
+def test_queue_server_param_type_negative():
+    rmq_server_param_type = QueueServerParamType()
+    incorrect_rmq_uris = [
+        f"amqpp://rmq_user:rmq_pass@localhost:5672",
+        f"rmq_user:rmq_pass@localhost:5672",
+        f"localhost:5672",
+        "localhost"
+    ]
+    for rmq_uri in incorrect_rmq_uris:
+        with raises(Exception):
+            rmq_server_param_type.convert(value=rmq_uri, param=None, ctx=None)
+
+
+def test_server_address_param_type_positive():
+    server_address_param_type = ServerAddressParamType()
+    correct_address = "localhost:8000"
+    server_address_param_type.convert(value=correct_address, param=None, ctx=None)
+
+
+def test_server_address_param_type_negative():
+    server_address_param_type = ServerAddressParamType()
+    incorrect_address = "8000:localhost"
+    with raises(Exception):
+        server_address_param_type.convert(value=incorrect_address, param=None, ctx=None)
diff --git a/tests/network/test_modules_process_helpers.py b/tests/network/test_modules_process_helpers.py
new file mode 100644
index 000000000..2e9ab8470
--- /dev/null
+++ b/tests/network/test_modules_process_helpers.py
@@ -0,0 +1,75 @@
+from contextlib import contextmanager
+from os import environ
+from pathlib import Path
+
+from ocrd.processor.builtin.dummy_processor import DummyProcessor
+from ocrd_network.constants import NetworkLoggingDirs
+from ocrd_network.logging_utils import get_root_logging_dir
+from ocrd_network.process_helpers import invoke_processor
+from ocrd_network.utils import generate_id
+
+from tests.base import assets
+
+@contextmanager
+def temp_env_var(k, v):
+    v_before = environ.get(k, None)
+    environ[k] = v
+    yield
+    if v_before is not None:
+        environ[k] = v_before
+    else:
+        del environ[k]
+
+
+def test_invoke_processor_bash():
+    scriptdir = Path(__file__).parent.parent / 'data'
+    with temp_env_var('PATH', f'{scriptdir}:{environ["PATH"]}'):
+        workspace_root = "kant_aufklaerung_1784/data"
+        path_to_mets = assets.path_to(f"{workspace_root}/mets.xml")
+        assert Path(path_to_mets).exists()
+        log_dir_root = get_root_logging_dir(module_name=NetworkLoggingDirs.PROCESSING_JOBS)
+        job_id = generate_id()
+        path_to_log_file = Path(log_dir_root, job_id)
+        input_file_grp = "OCR-D-IMG"
+        output_file_grp = f"OCR-D-BASH-TEST-{job_id}"
+        try:
+            invoke_processor(
+                processor_class=None,  # required only for pythonic processors
+                executable='ocrd-cp',
+                abs_path_to_mets=path_to_mets,
+                input_file_grps=[input_file_grp],
+                output_file_grps=[output_file_grp],
+                page_id="PHYS_0017,PHYS_0020",
+                parameters={},
+                log_filename=path_to_log_file,
+                log_level="DEBUG"
+            )
+        except Exception:
+            with open(path_to_log_file, 'r', encoding='utf-8') as f:
+                print(f.read())
+        assert Path(assets.path_to(f"{workspace_root}/{output_file_grp}")).exists()
+        assert Path(path_to_log_file).exists()
+
+
+def test_invoke_processor_pythonic():
+    workspace_root = "kant_aufklaerung_1784/data"
+    path_to_mets = assets.path_to(f"{workspace_root}/mets.xml")
+    assert Path(path_to_mets).exists()
+    log_dir_root = get_root_logging_dir(module_name=NetworkLoggingDirs.PROCESSING_JOBS)
+    job_id = generate_id()
+    path_to_log_file = Path(log_dir_root, job_id)
+    input_file_grp = "OCR-D-IMG"
+    output_file_grp = f"OCR-D-DUMMY-TEST-{job_id}"
+    invoke_processor(
+        processor_class=DummyProcessor,
+        executable="",  # not required for pythonic processors
+        abs_path_to_mets=path_to_mets,
+        input_file_grps=[input_file_grp],
+        output_file_grps=[output_file_grp],
+        page_id="PHYS_0017,PHYS_0020",
+        parameters={},
+        log_filename=path_to_log_file,
+        log_level="DEBUG"
+    )
+    assert Path(assets.path_to(f"{workspace_root}/{output_file_grp}")).exists()
+    assert Path(path_to_log_file).exists()
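
The temp_env_var context manager above restores the previous value (or removes the variable) on exit. pytest's built-in monkeypatch fixture offers the same behaviour with automatic teardown; a sketch of how the PATH manipulation could look with it (test name and assertion are illustrative only):

    import os
    from pathlib import Path

    def test_path_prepended(monkeypatch):
        # Prepend the test-data directory to PATH for this test only;
        # monkeypatch reverts the variable automatically on teardown.
        scriptdir = Path(__file__).parent.parent / 'data'
        monkeypatch.setenv("PATH", str(scriptdir), prepend=os.pathsep)
        assert os.environ["PATH"].startswith(str(scriptdir))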
diff --git a/tests/network/test_modules_server_cache_pages.py b/tests/network/test_modules_server_cache_pages.py
new file mode 100644
index 000000000..331a73844
--- /dev/null
+++ b/tests/network/test_modules_server_cache_pages.py
@@ -0,0 +1,135 @@
+from typing import List
+from src.ocrd_network.server_cache import CacheLockedPages
+
+
+def assert_locked_all_pages(pages_cache: CacheLockedPages, workspace_key: str, output_file_grps: List[str]):
+    assert len(pages_cache.locked_pages) == 1
+    ws_locked_pages_dict = pages_cache.locked_pages[workspace_key]
+    assert len(ws_locked_pages_dict) == len(output_file_grps)
+    for output_file_group in output_file_grps:
+        # The array contains a single element - the placeholder indicating all pages
+        assert len(ws_locked_pages_dict[output_file_group]) == 1
+        assert ws_locked_pages_dict[output_file_group][0] == pages_cache.placeholder_all_pages
+
+
+def assert_unlocked_all_pages(pages_cache: CacheLockedPages, workspace_key: str, output_file_grps: List[str]):
+    assert len(pages_cache.locked_pages) == 1
+    assert len(pages_cache.locked_pages[workspace_key]) == len(output_file_grps)
+    for output_file_group in output_file_grps:
+        assert len(pages_cache.locked_pages[workspace_key][output_file_group]) == 0
+
+
+def assert_locked_some_pages(
+    pages_cache: CacheLockedPages, workspace_key: str, output_file_grps: List[str], page_ids: List[str]
+):
+    assert len(pages_cache.locked_pages) == 1
+    ws_locked_pages_dict = pages_cache.locked_pages[workspace_key]
+    assert len(ws_locked_pages_dict) == len(output_file_grps)
+    for output_file_group in output_file_grps:
+        assert len(ws_locked_pages_dict[output_file_group]) == len(page_ids)
+        for page_id in page_ids:
+            assert ws_locked_pages_dict[output_file_group].count(page_id) == 1
+
+
+def assert_unlocked_some_pages(
+    pages_cache: CacheLockedPages, workspace_key: str, output_file_grps: List[str], page_ids: List[str]
+):
+    assert len(pages_cache.locked_pages) == 1
+    ws_locked_pages_dict = pages_cache.locked_pages[workspace_key]
+    assert len(ws_locked_pages_dict) == len(output_file_grps)
+    for output_file_group in output_file_grps:
+        assert len(ws_locked_pages_dict[output_file_group]) == 0
+        for page_id in page_ids:
+            assert ws_locked_pages_dict[output_file_group].count(page_id) == 0
+
+
+def test_lock_all_pages():
+    workspace_key: str = "test_workspace"
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=[])
+    assert_locked_all_pages(pages_cache, workspace_key, output_file_grps)
+
+
+def test_unlock_all_pages():
+    workspace_key: str = "test_workspace"
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=[])
+    assert_locked_all_pages(pages_cache, workspace_key, output_file_grps)
+    pages_cache.unlock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=[])
+    assert_unlocked_all_pages(pages_cache, workspace_key, output_file_grps)
+
+
+def test_lock_some_pages():
+    workspace_key: str = "test_workspace"
+    # Output file groups whose pages are to be locked
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+    # Pages to be locked for each output file group
+    page_ids: List[str] = ["PHYS_0001", "PHYS_0002", "PHYS_0003", "PHYS_0004"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=page_ids)
+    assert_locked_some_pages(pages_cache, workspace_key, output_file_grps, page_ids)
+
+
+def test_unlock_some_pages():
+    workspace_key: str = "test_workspace"
+    # Output file groups whose pages are to be locked
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+    # Pages to be locked for each output file group
+    page_ids: List[str] = ["PHYS_0001", "PHYS_0002", "PHYS_0003", "PHYS_0004"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=page_ids)
+    assert_locked_some_pages(pages_cache, workspace_key, output_file_grps, page_ids)
+    pages_cache.unlock_pages(workspace_key, output_file_grps, page_ids)
+    assert_unlocked_some_pages(pages_cache, workspace_key, output_file_grps, page_ids)
+
+
+def test_get_locked_pages():
+    workspace_key: str = "test_workspace"
+    # Output file groups whose pages are to be locked
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+    # Pages to be locked for each output file group
+    page_ids: List[str] = ["PHYS_0001", "PHYS_0002", "PHYS_0003", "PHYS_0004"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=page_ids)
+    assert_locked_some_pages(pages_cache, workspace_key, output_file_grps, page_ids)
+    assert pages_cache.get_locked_pages(workspace_key=workspace_key) == pages_cache.locked_pages[workspace_key]
+
+
+def test_check_if_locked_pages_for_output_file_grps():
+    workspace_key: str = "test_workspace"
+    # Output file groups whose pages are to be locked
+    output_file_grps: List[str] = ["OCR-D-IMG", "OCR-D-BIN"]
+    # Pages to be locked for each output file group
+    page_ids: List[str] = ["PHYS_0001", "PHYS_0002", "PHYS_0003", "PHYS_0004"]
+
+    pages_cache = CacheLockedPages()
+    pages_cache.lock_pages(workspace_key=workspace_key, output_file_grps=output_file_grps, page_ids=page_ids)
+    assert_locked_some_pages(pages_cache, workspace_key, output_file_grps, page_ids)
+
+    # Test for locked pages
+    assert pages_cache.check_if_locked_pages_for_output_file_grps(
+        workspace_key, output_file_grps=["OCR-D-IMG"], page_ids=["PHYS_0001", "PHYS_0002"]
+    )
+    assert pages_cache.check_if_locked_pages_for_output_file_grps(
+        workspace_key, output_file_grps=["OCR-D-BIN"], page_ids=["PHYS_0003", "PHYS_0004"]
+    )
+
+    # Test for non-locked pages
+    assert not pages_cache.check_if_locked_pages_for_output_file_grps(
+        workspace_key, output_file_grps=["OCR-D-IMG"], page_ids=["PHYS_0010", "PHYS_0011"]
+    )
+    assert not pages_cache.check_if_locked_pages_for_output_file_grps(
+        workspace_key, output_file_grps=["OCR-D-BIN"], page_ids=["PHYS_0010", "PHYS_0011"]
+    )
+
+    # Test for non-existing output file group
+    assert not pages_cache.check_if_locked_pages_for_output_file_grps(
+        workspace_key, output_file_grps=["OCR-D-OCR"], page_ids=["PHYS_0001", "PHYS_0002"]
+    )
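
Taken together, the assertions above imply that CacheLockedPages.locked_pages is a nested mapping from workspace key to output file group to the list of locked page IDs, with a single placeholder entry when a whole file group is locked. Roughly (an illustration of the implied shape only, not the actual implementation; the placeholder's real value is whatever pages_cache.placeholder_all_pages holds):

    # Implied shape of CacheLockedPages.locked_pages (illustrative values)
    locked_pages = {
        "test_workspace": {
            "OCR-D-IMG": ["PHYS_0001", "PHYS_0002"],   # individually locked pages
            "OCR-D-BIN": ["<placeholder_all_pages>"],  # placeholder meaning "all pages locked"
        }
    }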
diff --git a/tests/network/test_processing_server.py b/tests/network/test_processing_server.py
deleted file mode 100644
index 6d039f5bc..000000000
--- a/tests/network/test_processing_server.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from time import sleep
-from requests import get, post
-from src.ocrd_network import NETWORK_AGENT_WORKER
-from src.ocrd_network.models import StateEnum
-from tests.base import assets
-from tests.network.config import test_config
-
-PROCESSING_SERVER_URL = test_config.PROCESSING_SERVER_URL
-
-
-def poll_till_timeout_fail_or_success(test_url: str, tries: int, wait: int) -> StateEnum:
-    job_state = StateEnum.unset
-    while tries > 0:
-        sleep(wait)
-        response = get(url=test_url)
-        assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}"
-        job_state = response.json()["state"]
-        if job_state == StateEnum.success or job_state == StateEnum.failed:
-            break
-        tries -= 1
-    return job_state
-
-
-def test_processing_server_connectivity():
-    test_url = f'{PROCESSING_SERVER_URL}/'
-    response = get(test_url)
-    assert response.status_code == 200, \
-        f'Processing server is not reachable on: {test_url}, {response.status_code}'
-    message = response.json()['message']
-    assert message.startswith('The home page of'), \
-        f'Processing server home page message is corrupted'
-
-
-# TODO: The processing workers are still not registered when deployed separately.
-# Fix that by extending the processing server.
-def test_processing_server_deployed_processors():
-    test_url = f'{PROCESSING_SERVER_URL}/processor'
-    response = get(test_url)
-    processors = response.json()
-    assert response.status_code == 200, \
-        f'Processing server: {test_url}, {response.status_code}'
-    assert processors == [], f'Mismatch in deployed processors'
-
-
-def test_processing_server_processing_request():
-    path_to_mets = assets.path_to('kant_aufklaerung_1784/data/mets.xml')
-    test_processing_job_input = {
-        "path_to_mets": path_to_mets,
-        "input_file_grps": ['OCR-D-IMG'],
-        "output_file_grps": ['OCR-D-DUMMY'],
-        "agent_type": NETWORK_AGENT_WORKER,
-        "parameters": {}
-    }
-    test_processor = 'ocrd-dummy'
-    test_url = f'{PROCESSING_SERVER_URL}/processor/run/{test_processor}'
-    response = post(
-        url=test_url,
-        headers={"accept": "application/json"},
-        json=test_processing_job_input
-    )
-    # print(response.json())
-    assert response.status_code == 200, \
-        f'Processing server: {test_url}, {response.status_code}'
-    processing_job_id = response.json()["job_id"]
-    assert processing_job_id
-
-    job_state = poll_till_timeout_fail_or_success(
-        test_url=f"{PROCESSING_SERVER_URL}/processor/job/{processing_job_id}",
-        tries=10,
-        wait=10
-    )
-    assert job_state == StateEnum.success
-
-
-def test_processing_server_workflow_request():
-    # Note: the used workflow path is volume mapped
-    path_to_dummy_wf = "/ocrd-data/assets/dummy-workflow.txt"
-    path_to_mets = assets.path_to('kant_aufklaerung_1784/data/mets.xml')
-
-    # submit the workflow job
-    test_url = f"{PROCESSING_SERVER_URL}/workflow/run?mets_path={path_to_mets}&page_wise=True"
-    response = post(
-        url=test_url,
-        headers={"accept": "application/json"},
-        files={"workflow": open(path_to_dummy_wf, 'rb')}
-    )
-    # print(response.json())
-    assert response.status_code == 200, f"Processing server: {test_url}, {response.status_code}"
-    wf_job_id = response.json()["job_id"]
-    assert wf_job_id
-
-    job_state = poll_till_timeout_fail_or_success(
-        test_url=f"{PROCESSING_SERVER_URL}/workflow/job-simple/{wf_job_id}",
-        tries=30,
-        wait=10
-    )
-    assert job_state == StateEnum.success