
Commit

another portion of improvements
MehmedGIT committed Feb 19, 2024
1 parent 59e14b1 commit 0cb258e
Showing 27 changed files with 494 additions and 545 deletions.
2 changes: 1 addition & 1 deletion src/ocrd/cli/__init__.py
@@ -43,7 +43,7 @@
\b
{config.describe('OCRD_NETWORK_SERVER_ADDR_WORKSPACE')}
\b
{config.describe('OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS')}
{config.describe('OCRD_NETWORK_RABBITMQ_CLIENT_CONNECT_ATTEMPTS')}
\b
{config.describe('OCRD_PROFILE_FILE')}
\b
27 changes: 27 additions & 0 deletions src/ocrd_network/database.py
@@ -16,6 +16,8 @@
from beanie.operators import In
from motor.motor_asyncio import AsyncIOMotorClient
from pathlib import Path
from pymongo import MongoClient, uri_parser as mongo_uri_parser
from re import sub as re_sub
from typing import List
from uuid import uuid4

@@ -248,3 +250,28 @@ async def db_find_first_workflow_script_by_content(content_hash: str) -> DBWorkf
@call_sync
async def sync_db_find_first_workflow_script_by_content(workflow_id: str) -> DBWorkflowScript:
return await db_get_workflow_script(workflow_id)


def verify_database_uri(mongodb_address: str) -> str:
try:
# perform validation check
mongo_uri_parser.parse_uri(uri=mongodb_address, validate=True)
except Exception as error:
raise ValueError(f"The MongoDB address '{mongodb_address}' is in wrong format, {error}")
return mongodb_address


def verify_mongodb_available(mongo_url: str) -> None:
"""
# The protocol is intentionally set to HTTP instead of MONGODB!
mongodb_test_url = mongo_url.replace("mongodb", "http")
if is_url_responsive(url=mongodb_test_url, tries=3):
return
raise RuntimeError(f"Verifying connection has failed: {mongodb_test_url}")
"""

try:
client = MongoClient(mongo_url, serverSelectionTimeoutMS=60000.0)
client.admin.command("ismaster")
except Exception:
raise RuntimeError(f'Cannot connect to MongoDB: {re_sub(r":[^@]+@", ":****@", mongo_url)}')
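
For context, a minimal sketch of how these two new helpers could be combined before a deployment step; the wrapper function below is hypothetical and not part of this commit:

from ocrd_network.database import verify_database_uri, verify_mongodb_available


def resolve_verified_mongodb(mongodb_address: str) -> str:
    # Hypothetical wrapper: validate the URI format first (raises ValueError
    # on a malformed address), then probe the server (raises RuntimeError
    # if no MongoDB instance answers within the timeout).
    verified_url = verify_database_uri(mongodb_address)
    verify_mongodb_available(verified_url)
    return verified_url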
3 changes: 2 additions & 1 deletion src/ocrd_network/param_validators.py
@@ -1,6 +1,7 @@
from click import ParamType

from .utils import verify_database_uri, verify_and_parse_mq_uri
from .database import verify_database_uri
from .rabbitmq_utils import verify_and_parse_mq_uri


class ServerAddressParamType(ParamType):
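The class bodies are collapsed in this view. As an illustration only, a click parameter type along these lines could delegate to verify_database_uri; the class name and exact behaviour below are assumptions, not taken from this commit:

from click import ParamType

from ocrd_network.database import verify_database_uri


class DatabaseParamType(ParamType):
    name = "database"

    def convert(self, value, param, ctx):
        # Delegate the format check to the database helper and surface
        # failures through click's standard error reporting.
        try:
            return verify_database_uri(value)
        except ValueError as error:
            self.fail(str(error), param, ctx)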
185 changes: 51 additions & 134 deletions src/ocrd_network/processing_server.py
@@ -1,19 +1,14 @@
from datetime import datetime
from httpx import AsyncClient, Timeout
from json import dumps, loads
from os import getpid
from requests import get as requests_get
from typing import Dict, List, Union
from urllib.parse import urljoin
from uvicorn import run as uvicorn_run

from fastapi import APIRouter, FastAPI, File, HTTPException, Request, status, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse
from pika.exceptions import ChannelClosedByBroker

from ocrd.task_sequence import ProcessorTask
from ocrd_utils import initLogging, getLogger, LOG_FORMAT
from ocrd_utils import initLogging, getLogger
from .constants import AgentType, JobState, OCRD_ALL_JSON_TOOLS_URL, ServerApiTags
from .database import (
initiate_database,
@@ -35,11 +30,17 @@
PYResultMessage,
PYWorkflowJobOutput
)
from .rabbitmq_utils import RMQPublisher, OcrdProcessingMessage
from .rabbitmq_utils import (
check_if_queue_exists,
connect_rabbitmq_publisher,
create_message_queues,
OcrdProcessingMessage
)
from .server_cache import CacheLockedPages, CacheProcessingRequests
from .server_utils import (
create_processing_message,
create_workspace_if_not_exists,
forward_job_to_processor_server,
_get_processor_job,
_get_processor_job_log,
get_page_ids_list,
@@ -48,13 +49,13 @@
get_from_database_workflow_job,
parse_workflow_tasks,
raise_http_exception,
request_processor_server_tool_json,
validate_and_return_mets_path,
validate_first_task_input_file_groups_existence,
validate_job_input,
validate_workflow
)
from .utils import (
calculate_processing_request_timeout,
download_ocrd_all_tool_json,
expand_page_ids,
generate_id,
@@ -104,12 +105,13 @@ def __init__(self, config_path: str, host: str, port: int) -> None:

self.mongodb_url = None
self.rabbitmq_url = None
# TODO: Combine these under a single URL, rabbitmq_utils needs an update
self.rmq_host = self.deployer.data_queue.host
self.rmq_port = self.deployer.data_queue.port
self.rmq_vhost = "/"
self.rmq_username = self.deployer.data_queue.cred_username
self.rmq_password = self.deployer.data_queue.cred_password
self.rmq_data = {
"host": self.deployer.data_queue.host,
"port": self.deployer.data_queue.port,
"vhost": "/",
"username": self.deployer.data_queue.cred_username,
"password": self.deployer.data_queue.cred_password
}

# Gets assigned when `connect_rabbitmq_publisher()` is called on the working object
self.rmq_publisher = None
@@ -139,9 +141,13 @@ def start(self) -> None:
self.mongodb_url = self.deployer.deploy_mongodb()

# The RMQPublisher is initialized and a connection to the RabbitMQ is performed
self.connect_rabbitmq_publisher()
self.rmq_publisher = connect_rabbitmq_publisher(self.log, self.rmq_data, enable_acks=True)

queue_names = self.deployer.find_matching_processors(
worker_only=True, str_names_only=True, unique_only=True
)
self.log.debug(f"Creating message queues on RabbitMQ instance url: {self.rabbitmq_url}")
self.create_message_queues()
create_message_queues(logger=self.log, rmq_publisher=self.rmq_publisher, queue_names=queue_names)

self.deployer.deploy_network_agents(mongodb_url=self.mongodb_url, rabbitmq_url=self.rabbitmq_url)
except Exception as error:
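
connect_rabbitmq_publisher and create_message_queues now live in rabbitmq_utils. Their bodies are not part of this excerpt; judging from the removed methods further below, they would look roughly like the following sketch (not the exact implementation):

from logging import Logger
from typing import Dict, List

from ocrd_network.rabbitmq_utils import RMQPublisher


def connect_rabbitmq_publisher(logger: Logger, rmq_data: Dict, enable_acks: bool = True) -> RMQPublisher:
    # Sketch based on the removed ProcessingServer.connect_rabbitmq_publisher() below
    logger.info(f"Connecting RMQPublisher to RabbitMQ server: {rmq_data['host']}:{rmq_data['port']}{rmq_data['vhost']}")
    rmq_publisher = RMQPublisher(host=rmq_data["host"], port=rmq_data["port"], vhost=rmq_data["vhost"])
    rmq_publisher.authenticate_and_connect(username=rmq_data["username"], password=rmq_data["password"])
    if enable_acks:
        rmq_publisher.enable_delivery_confirmations()
        logger.info("Delivery confirmations are enabled")
    logger.info("Successfully connected RMQPublisher")
    return rmq_publisher


def create_message_queues(logger: Logger, rmq_publisher: RMQPublisher, queue_names: List[str]) -> None:
    # Sketch based on the removed ProcessingServer.create_message_queues() below
    for queue_name in queue_names:
        logger.info(f"Creating a message queue with id: {queue_name}")
        rmq_publisher.create_queue(queue_name=queue_name)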
@@ -306,91 +312,25 @@ async def home_page(self):
async def stop_deployed_agents(self) -> None:
self.deployer.stop_all()

def connect_rabbitmq_publisher(self, enable_acks: bool = True) -> None:
self.log.info(f'Connecting RMQPublisher to RabbitMQ server: '
f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}')
self.rmq_publisher = RMQPublisher(
host=self.rmq_host,
port=self.rmq_port,
vhost=self.rmq_vhost
)
self.log.debug(f'RMQPublisher authenticates with username: '
f'{self.rmq_username}, password: {self.rmq_password}')
self.rmq_publisher.authenticate_and_connect(
username=self.rmq_username,
password=self.rmq_password
)
if enable_acks:
self.rmq_publisher.enable_delivery_confirmations()
self.log.info('Delivery confirmations are enabled')
self.log.info('Successfully connected RMQPublisher.')

def create_message_queues(self) -> None:
""" Create the message queues based on the occurrence of
`workers.name` in the config file.
"""

# The abstract version of the above lines
queue_names = self.deployer.find_matching_processors(
worker_only=True,
str_names_only=True,
unique_only=True
)

# TODO: Reconsider and refactor this.
# Added ocrd-dummy by default if not available for the integration tests.
# A proper Processing Worker / Processor Server registration endpoint is needed on the Processing Server side
if 'ocrd-dummy' not in queue_names:
queue_names.append('ocrd-dummy')

for queue_name in queue_names:
# The existence/validity of the worker.name is not tested.
# Even if an ocr-d processor does not exist, the queue is created
self.log.info(f'Creating a message queue with id: {queue_name}')
self.rmq_publisher.create_queue(queue_name=queue_name)

def check_if_queue_exists(self, processor_name: str) -> bool:
try:
# Only checks if the process queue exists, if not raises ChannelClosedByBroker
self.rmq_publisher.create_queue(processor_name, passive=True)
return True
except ChannelClosedByBroker as error:
self.log.warning(f"Process queue with id '{processor_name}' not existing: {error}")
# TODO: Revisit when reconnection strategy is implemented
# Reconnect publisher, i.e., restore the connection - not efficient, but works
self.connect_rabbitmq_publisher(enable_acks=True)
return False

def query_ocrd_tool_json_from_server(self, processor_server_url: str):
# Request the ocrd tool json from the Processor Server
try:
response = requests_get(
urljoin(base=processor_server_url, url="info"),
headers={"Content-Type": "application/json"}
)
if response.status_code != 200:
message = f"Failed to retrieve tool json from: {processor_server_url}, code: {response.status_code}"
raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message)
return response.json()
except Exception as error:
message = f"Failed to retrieve ocrd tool json from: {processor_server_url}"
raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message, error)
def query_ocrd_tool_json_from_server(self, processor_name: str) -> Dict:
processor_server_base_url = self.deployer.resolve_processor_server_url(processor_name)
if processor_server_base_url == '':
message = f"Processor Server URL of '{processor_name}' not found"
raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, message=message)
return request_processor_server_tool_json(self.log, processor_server_base_url=processor_server_base_url)

async def get_network_agent_ocrd_tool(
self, processor_name: str, agent_type: AgentType = AgentType.PROCESSING_WORKER
) -> Dict:
ocrd_tool = {}
error_message = f"Network agent of type '{agent_type}' for processor '{processor_name}' not found."
if agent_type == AgentType.PROCESSING_WORKER:
ocrd_tool = self.ocrd_all_tool_json.get(processor_name, None)
elif agent_type == AgentType.PROCESSOR_SERVER:
processor_server_url = self.deployer.resolve_processor_server_url(processor_name)
if processor_server_url == '':
raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, error_message)
ocrd_tool = self.query_ocrd_tool_json_from_server(processor_server_url)
else:
if agent_type != AgentType.PROCESSING_WORKER and agent_type != AgentType.PROCESSOR_SERVER:
message = f"Unknown agent type: {agent_type}, {type(agent_type)}"
raise_http_exception(self.log, status_code=status.HTTP_501_NOT_IMPLEMENTED, message=message)
if agent_type == AgentType.PROCESSING_WORKER:
ocrd_tool = self.ocrd_all_tool_json.get(processor_name, None)
if agent_type == AgentType.PROCESSOR_SERVER:
ocrd_tool = self.query_ocrd_tool_json_from_server(processor_name)
if not ocrd_tool:
raise_http_exception(self.log, status.HTTP_404_NOT_FOUND, error_message)
return ocrd_tool
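
check_if_queue_exists has likewise moved to rabbitmq_utils; its new signature can be seen at the call site further below. A rough sketch based on the removed method above (that it opens its own publisher connection from rmq_data is an assumption, not confirmed by this excerpt):

from logging import Logger
from typing import Dict

from pika.exceptions import ChannelClosedByBroker


def check_if_queue_exists(logger: Logger, rmq_data: Dict, processor_name: str) -> bool:
    # Sketch based on the removed ProcessingServer.check_if_queue_exists() above.
    # A passive declaration only checks for existence; a missing queue makes
    # the broker close the channel with ChannelClosedByBroker.
    rmq_publisher = connect_rabbitmq_publisher(logger, rmq_data, enable_acks=True)
    try:
        rmq_publisher.create_queue(processor_name, passive=True)
        return True
    except ChannelClosedByBroker as error:
        logger.warning(f"Process queue with id '{processor_name}' does not exist: {error}")
        return False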
@@ -406,7 +346,7 @@ def network_agent_exists_worker(self, processor_name: str) -> bool:
# is needed on the Processing Server side
if processor_name == 'ocrd-dummy':
return True
return bool(self.check_if_queue_exists(processor_name=processor_name))
return bool(check_if_queue_exists(self.log, self.rmq_data, processor_name=processor_name))

def validate_agent_type_and_existence(self, processor_name: str, agent_type: AgentType) -> None:
agent_exists = False
@@ -513,64 +453,41 @@ async def validate_and_forward_job_to_network_agent(self, processor_name: str, d
return job_output

async def push_job_to_network_agent(self, data: PYJobInput, db_job: DBProcessorJob) -> PYJobOutput:
job_output = None
if data.agent_type == AgentType.PROCESSING_WORKER:
processing_message = create_processing_message(self.log, db_job)
self.log.debug(f"Pushing to processing worker: {data.processor_name}, {data.page_id}, {data.job_id}")
await self.push_job_to_processing_queue(data.processor_name, processing_message)
job_output = db_job.to_job_output()
elif data.agent_type == AgentType.PROCESSOR_SERVER:
self.log.debug(f"Pushing to processor server: {data.processor_name}, {data.page_id}, {data.job_id}")
job_output = await self.push_job_to_processor_server(data.processor_name, data)
else:
if data.agent_type != AgentType.PROCESSING_WORKER and data.agent_type != AgentType.PROCESSOR_SERVER:
message = f"Unknown agent type: {data.agent_type}, {type(data.agent_type)}"
raise_http_exception(self.log, status_code=status.HTTP_501_NOT_IMPLEMENTED, message=message)
job_output = None
self.log.debug(f"Pushing to {data.agent_type}: {data.processor_name}, {data.page_id}, {data.job_id}")
if data.agent_type == AgentType.PROCESSING_WORKER:
job_output = await self.push_job_to_processing_queue(db_job=db_job)
if data.agent_type == AgentType.PROCESSOR_SERVER:
job_output = await self.push_job_to_processor_server(job_input=data)
if not job_output:
message = f"Failed to create job output for job input: {data}"
raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message)
return job_output

async def push_job_to_processing_queue(self, processor_name: str, processing_message: OcrdProcessingMessage):
async def push_job_to_processing_queue(self, db_job: DBProcessorJob) -> PYJobOutput:
if not self.rmq_publisher:
message = "The Processing Server has no connection to RabbitMQ Server. RMQPublisher is not connected."
raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message)
processing_message = create_processing_message(self.log, db_job)
try:
encoded_message = OcrdProcessingMessage.encode_yml(processing_message)
self.rmq_publisher.publish_to_queue(queue_name=processor_name, message=encoded_message)
self.rmq_publisher.publish_to_queue(queue_name=db_job.processor_name, message=encoded_message)
except Exception as error:
message = (
f"Processing server has failed to push processing message to queue: {processor_name}, "
f"Processing server has failed to push processing message to queue: {db_job.processor_name}, "
f"Processing message: {processing_message.__dict__}"
)
raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message, error)
return db_job.to_job_output()

async def push_job_to_processor_server(self, processor_name: str, job_input: PYJobInput) -> PYJobOutput:
try:
json_data = dumps(job_input.dict(exclude_unset=True, exclude_none=True))
except Exception as error:
message = f"Failed to json dump the PYJobInput: {job_input}"
raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message, error)

processor_server_url = self.deployer.resolve_processor_server_url(processor_name)

# TODO: The amount of pages should come as a request input
# TODO: cf https://github.com/OCR-D/core/pull/1030/files#r1152551161
# currently, use 200 as a default
request_timeout = calculate_processing_request_timeout(amount_pages=200, timeout_per_page=20.0)

# Post a processing job to the Processor Server asynchronously
async with AsyncClient(timeout=Timeout(timeout=request_timeout, connect=30.0)) as client:
response = await client.post(
urljoin(base=processor_server_url, url="run"),
headers={"Content-Type": "application/json"},
json=loads(json_data)
)

if response.status_code != 202:
message = f"Failed to post '{processor_name}' job to: {processor_server_url}"
raise_http_exception(self.log, status.HTTP_500_INTERNAL_SERVER_ERROR, message)
job_output = response.json()
return job_output
async def push_job_to_processor_server(self, job_input: PYJobInput) -> PYJobOutput:
processor_server_base_url = self.deployer.resolve_processor_server_url(job_input.processor_name)
return await forward_job_to_processor_server(
self.log, job_input=job_input, processor_server_base_url=processor_server_base_url
)

async def get_processor_job(self, job_id: str) -> PYJobOutput:
return await _get_processor_job(self.log, job_id)
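
push_job_to_processor_server now delegates to forward_job_to_processor_server from server_utils. Based on the removed body above, a sketch of the extracted helper could read as follows (module paths and details in server_utils.py are assumptions):

from json import dumps, loads
from logging import Logger
from urllib.parse import urljoin

from fastapi import status
from httpx import AsyncClient, Timeout

from ocrd_network.models import PYJobInput, PYJobOutput
from ocrd_network.server_utils import raise_http_exception
from ocrd_network.utils import calculate_processing_request_timeout


async def forward_job_to_processor_server(
    logger: Logger, job_input: PYJobInput, processor_server_base_url: str
) -> PYJobOutput:
    # Sketch based on the removed ProcessingServer.push_job_to_processor_server() above
    try:
        json_data = dumps(job_input.dict(exclude_unset=True, exclude_none=True))
    except Exception as error:
        message = f"Failed to json dump the PYJobInput: {job_input}"
        raise_http_exception(logger, status.HTTP_500_INTERNAL_SERVER_ERROR, message, error)
    # The original TODO still applies: the amount of pages should come as a
    # request input; 200 pages at 20.0 seconds per page is used as a default.
    request_timeout = calculate_processing_request_timeout(amount_pages=200, timeout_per_page=20.0)
    # Post the processing job to the Processor Server asynchronously
    async with AsyncClient(timeout=Timeout(timeout=request_timeout, connect=30.0)) as client:
        response = await client.post(
            urljoin(base=processor_server_base_url, url="run"),
            headers={"Content-Type": "application/json"},
            json=loads(json_data)
        )
    if response.status_code != 202:
        message = f"Failed to post '{job_input.processor_name}' job to: {processor_server_base_url}"
        raise_http_exception(logger, status.HTTP_500_INTERNAL_SERVER_ERROR, message)
    return response.json()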