"""
Abstraction for the Processing Server unit in this arch:
https://user-images.githubusercontent.com/7795705/203554094-62ce135a-b367-49ba-9960-ffe1b7d39b2c.jpg
Calls to native OCR-D processor should happen through
the Processing Worker wrapper to hide low level details.
According to the current requirements, each ProcessingWorker
is a single OCR-D Processor instance.
"""
from datetime import datetime
from logging import FileHandler, Formatter
from os import getpid
from time import sleep
import pika.spec
import pika.adapters.blocking_connection
from pika.exceptions import AMQPConnectionError
from ocrd_utils import config, getLogger, LOG_FORMAT
from .database import (
sync_initiate_database,
sync_db_get_workspace,
sync_db_update_processing_job,
)
from .logging import (
get_processing_job_logging_file_path,
get_processing_worker_logging_file_path
)
from .models import StateEnum
from .process_helpers import invoke_processor
from .rabbitmq_utils import (
OcrdProcessingMessage,
OcrdResultMessage,
RMQConsumer,
RMQPublisher
)
from .utils import (
calculate_execution_time,
post_to_callback_url,
verify_database_uri,
verify_and_parse_mq_uri
)
class ProcessingWorker:
    """Worker that serves processing requests for a single OCR-D processor.

    The worker consumes ``OcrdProcessingMessage`` messages from the RabbitMQ
    queue named after ``processor_name``, runs the processor through
    ``invoke_processor``, records the job state in the database and reports
    the outcome via a result queue and/or callback URLs when the message
    requests it.
    """

    def __init__(self, rabbitmq_addr, mongodb_addr, processor_name, ocrd_tool: dict, processor_class=None) -> None:
        """Validate the service addresses, set up logging and create the worker queue.

        Args:
            rabbitmq_addr: RabbitMQ URI, parsed with ``verify_and_parse_mq_uri``.
            mongodb_addr: MongoDB URI, checked with ``verify_database_uri``.
            processor_name: name of the OCR-D processor; also used as queue name.
            ocrd_tool: the ocrd-tool description of the processor.
            processor_class: optional processor class handed to ``invoke_processor``.

        Raises:
            ValueError: if one of the URIs fails verification.
        """
        self.log = getLogger('ocrd_network.processing_worker')
        log_file = get_processing_worker_logging_file_path(processor_name=processor_name, pid=getpid())
        file_handler = FileHandler(filename=log_file, mode='a')
        file_handler.setFormatter(Formatter(LOG_FORMAT))
        self.log.addHandler(file_handler)

        try:
            verify_database_uri(mongodb_addr)
            self.log.debug(f'Verified MongoDB URL: {mongodb_addr}')
            rmq_data = verify_and_parse_mq_uri(rabbitmq_addr)
            self.rmq_username = rmq_data['username']
            self.rmq_password = rmq_data['password']
            self.rmq_host = rmq_data['host']
            self.rmq_port = rmq_data['port']
            self.rmq_vhost = rmq_data['vhost']
            # Never write the password itself into the log file
            self.log.debug(f'Verified RabbitMQ Credentials: {self.rmq_username}:<redacted>')
            self.log.debug(f'Verified RabbitMQ Server URL: {self.rmq_host}:{self.rmq_port}{self.rmq_vhost}')
        except ValueError as e:
            # Keep the original error as the cause so the traceback is not lost
            raise ValueError(e) from e

        sync_initiate_database(mongodb_addr)  # Database client
        self.ocrd_tool = ocrd_tool
        # The str name of the OCR-D processor instance to be started
        self.processor_name = processor_name
        # The processor class to be used to instantiate the processor
        # Think of this as a func pointer to the constructor of the respective OCR-D processor
        self.processor_class = processor_class
        # Gets assigned when `connect_consumer` is called on the worker object
        # Used to consume OcrdProcessingMessage from the queue with name {processor_name}
        self.rmq_consumer = None
        # Gets assigned when the `connect_publisher` is called on the worker object
        # The publisher is connected when the `result_queue` field of the OcrdProcessingMessage is set for first time
        # Used to publish OcrdResultMessage type message to the queue with name {processor_name}-result
        self.rmq_publisher = None
        # Always create a queue (idempotent)
        self.create_queue()

    def connect_consumer(self) -> None:
        """Create an ``RMQConsumer`` and connect it to the RabbitMQ server.

        ``self.rmq_consumer`` is only assigned once the connection succeeded,
        so a failed attempt never leaves an unconnected consumer behind.
        """
        self.log.info(f'Connecting RMQConsumer to RabbitMQ server: '
                      f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}')
        consumer = RMQConsumer(
            host=self.rmq_host,
            port=self.rmq_port,
            vhost=self.rmq_vhost
        )
        # The password is intentionally not logged
        self.log.debug(f'RMQConsumer authenticates with username: {self.rmq_username}')
        consumer.authenticate_and_connect(
            username=self.rmq_username,
            password=self.rmq_password
        )
        self.rmq_consumer = consumer
        self.log.info('Successfully connected RMQConsumer.')

    def connect_publisher(self, enable_acks: bool = True) -> None:
        """Create an ``RMQPublisher`` and connect it to the RabbitMQ server.

        Args:
            enable_acks: when true, delivery confirmations are enabled on the publisher.

        ``self.rmq_publisher`` is only assigned once the connection succeeded,
        so the ``is None`` checks elsewhere in this class stay meaningful.
        """
        self.log.info(f'Connecting RMQPublisher to RabbitMQ server: '
                      f'{self.rmq_host}:{self.rmq_port}{self.rmq_vhost}')
        publisher = RMQPublisher(
            host=self.rmq_host,
            port=self.rmq_port,
            vhost=self.rmq_vhost
        )
        # The password is intentionally not logged
        self.log.debug(f'RMQPublisher authenticates with username: {self.rmq_username}')
        publisher.authenticate_and_connect(
            username=self.rmq_username,
            password=self.rmq_password
        )
        if enable_acks:
            publisher.enable_delivery_confirmations()
            self.log.info('Delivery confirmations are enabled')
        self.rmq_publisher = publisher
        self.log.info('Successfully connected RMQPublisher.')

    def on_consumed_message(
        self,
        channel: pika.adapters.blocking_connection.BlockingChannel,
        delivery: pika.spec.Basic.Deliver,
        properties: pika.spec.BasicProperties,
        body: bytes
    ) -> None:
        """Handle one message consumed from the queue named ``self.processor_name``.

        The body is decoded into an ``OcrdProcessingMessage`` and handed to
        ``process_message``. On success the message is acked. If decoding or
        processing raises, the message is nacked without requeueing and an
        ``Exception`` (chained to the original error) is raised.
        """
        consumer_tag = delivery.consumer_tag
        delivery_tag: int = delivery.delivery_tag
        is_redelivered: bool = delivery.redelivered
        message_headers: dict = properties.headers

        self.log.debug(f'Consumer tag: {consumer_tag}, '
                       f'message delivery tag: {delivery_tag}, '
                       f'redelivered: {is_redelivered}')
        self.log.debug(f'Message headers: {message_headers}')

        try:
            self.log.debug(f'Trying to decode processing message with tag: {delivery_tag}')
            processing_message: OcrdProcessingMessage = OcrdProcessingMessage.decode_yml(body)
        except Exception as e:
            self.log.error(f'Failed to decode processing message body: {body}')
            self.log.error(f'Nacking processing message with tag: {delivery_tag}')
            channel.basic_nack(delivery_tag=delivery_tag, multiple=False, requeue=False)
            raise Exception(f'Failed to decode processing message with tag: {delivery_tag}, reason: {e}') from e

        try:
            self.log.info(f'Starting to process the received message: {processing_message.__dict__}')
            self.process_message(processing_message=processing_message)
        except Exception as e:
            self.log.error(f'Failed to process processing message with tag: {delivery_tag}')
            self.log.error(f'Nacking processing message with tag: {delivery_tag}')
            channel.basic_nack(delivery_tag=delivery_tag, multiple=False, requeue=False)
            raise Exception(f'Failed to process processing message with tag: {delivery_tag}, reason: {e}') from e

        self.log.info('Successfully processed RabbitMQ message')
        self.log.debug(f'Acking message with tag: {delivery_tag}')
        channel.basic_ack(delivery_tag=delivery_tag, multiple=False)

    def start_consuming(self) -> None:
        """Register ``on_consumed_message`` on the worker queue and start consuming.

        Raises:
            Exception: if ``connect_consumer`` has not set up a consumer yet.
        """
        if not self.rmq_consumer:
            raise Exception('The RMQConsumer is not connected/configured properly')

        self.log.info(f'Configuring consuming from queue: {self.processor_name}')
        self.rmq_consumer.configure_consuming(
            queue_name=self.processor_name,
            callback_method=self.on_consumed_message
        )
        self.log.info(f'Starting consuming from queue: {self.processor_name}')
        # Starting consuming is a blocking action
        self.rmq_consumer.start_consuming()

    # TODO: Better error handling required to catch exceptions
    def process_message(self, processing_message: OcrdProcessingMessage) -> None:
        """Run the processor for one processing message and report the result.

        The job is marked as running in the database, the processor is invoked,
        and the final state (success or failed) plus timing is written back.
        An ``OcrdResultMessage`` is then published to the result queue and/or
        posted to the callback URLs, for whichever of those the message sets.

        An exception raised by ``invoke_processor`` is logged and turns the job
        state into failed; it is not re-raised.

        Raises:
            ValueError: if the processor name of the message does not match this
                worker, or if neither `path_to_mets` nor `workspace_id` is set.
        """
        # Verify that the processor name in the processing message
        # matches the processor name of the current processing worker
        if self.processor_name != processing_message.processor_name:
            raise ValueError(f'Processor name is not matching. Expected: {self.processor_name}, '
                             f'Got: {processing_message.processor_name}')

        # All of this is needed because the OcrdProcessingMessage object
        # may not contain certain keys. Simply passing None in the OcrdProcessingMessage constructor
        # breaks the message validator schema which expects String, but not None due to the Optional[] wrapper.
        pm_fields = vars(processing_message)
        job_id = processing_message.job_id
        input_file_grps = processing_message.input_file_grps
        output_file_grps = pm_fields.get('output_file_grps')
        path_to_mets = pm_fields.get('path_to_mets')
        workspace_id = pm_fields.get('workspace_id')
        page_id = pm_fields.get('page_id')
        result_queue_name = pm_fields.get('result_queue_name')
        callback_url = pm_fields.get('callback_url')
        internal_callback_url = pm_fields.get('internal_callback_url')
        parameters = processing_message.parameters if processing_message.parameters else {}

        if not path_to_mets and not workspace_id:
            raise ValueError('`path_to_mets` nor `workspace_id` was set in the ocrd processing message')

        if path_to_mets:
            mets_server_url = sync_db_get_workspace(workspace_mets_path=path_to_mets).mets_server_url
        else:
            # Only `workspace_id` is available: resolve both values with a single lookup
            db_workspace = sync_db_get_workspace(workspace_id)
            path_to_mets = db_workspace.workspace_mets_path
            mets_server_url = db_workspace.mets_server_url

        execution_failed = False
        self.log.debug(f'Invoking processor: {self.processor_name}')
        start_time = datetime.now()
        job_log_file = get_processing_job_logging_file_path(job_id=job_id)
        sync_db_update_processing_job(
            job_id=job_id,
            state=StateEnum.running,
            path_to_mets=path_to_mets,
            start_time=start_time,
            log_file_path=job_log_file
        )
        try:
            invoke_processor(
                processor_class=self.processor_class,
                executable=self.processor_name,
                abs_path_to_mets=path_to_mets,
                input_file_grps=input_file_grps,
                output_file_grps=output_file_grps,
                page_id=page_id,
                log_filename=job_log_file,
                # Pass the normalised dict so the processor never receives None
                parameters=parameters,
                mets_server_url=mets_server_url
            )
        except Exception as error:
            self.log.debug(f"processor_name: {self.processor_name}, path_to_mets: {path_to_mets}, "
                           f"input_grps: {input_file_grps}, output_file_grps: {output_file_grps}, "
                           f"page_id: {page_id}, parameters: {parameters}")
            self.log.exception(error)
            execution_failed = True
        end_time = datetime.now()
        exec_duration = calculate_execution_time(start_time, end_time)
        job_state = StateEnum.failed if execution_failed else StateEnum.success
        sync_db_update_processing_job(
            job_id=job_id,
            state=job_state,
            end_time=end_time,
            exec_time=f'{exec_duration} ms'
        )
        result_message = OcrdResultMessage(
            job_id=job_id,
            state=job_state.value,
            path_to_mets=path_to_mets,
            # May not be always available
            workspace_id=workspace_id
        )
        self.log.info(f'Result message: {result_message.__dict__}')
        # If the result_queue field is set, send the result message to a result queue
        if result_queue_name:
            self.publish_to_result_queue(result_queue_name, result_message)
        if callback_url:
            # If the callback_url field is set,
            # post the result message (callback to a user defined endpoint)
            post_to_callback_url(self.log, callback_url, result_message)
        if internal_callback_url:
            # If the internal callback_url field is set,
            # post the result message (callback to Processing Server endpoint)
            post_to_callback_url(self.log, internal_callback_url, result_message)

    def publish_to_result_queue(self, result_queue: str, result_message: OcrdResultMessage):
        """Publish ``result_message`` (YAML encoded) to the queue ``result_queue``.

        The publisher is connected on first use and the queue is created
        before publishing.
        """
        if self.rmq_publisher is None:
            self.connect_publisher()
        # create_queue method is idempotent - nothing happens if
        # a queue with the specified name already exists
        self.rmq_publisher.create_queue(queue_name=result_queue)
        self.log.info(f'Publishing result message to queue: {result_queue}')
        encoded_result_message = OcrdResultMessage.encode_yml(result_message)
        self.rmq_publisher.publish_to_queue(
            queue_name=result_queue,
            message=encoded_result_message
        )

    def create_queue(
        self,
        connection_attempts: int = config.OCRD_NETWORK_WORKER_QUEUE_CONNECT_ATTEMPTS,
        retry_delay: int = 3
    ) -> None:
        """Create the queue for this worker

        Originally only the processing-server created the queues for the workers according to the
        configuration file. This is intended to make external deployment of workers possible.

        Args:
            connection_attempts: how often connecting the publisher is tried;
                values below 1 are treated as a single attempt.
            retry_delay: seconds to sleep between two connection attempts.

        Raises:
            AMQPConnectionError: if the last connection attempt fails.
        """
        if self.rmq_publisher is None:
            attempts_left = connection_attempts if connection_attempts > 0 else 1
            while attempts_left > 0:
                try:
                    self.connect_publisher()
                    break
                except AMQPConnectionError:
                    if attempts_left <= 1:
                        # Out of attempts: propagate with the original traceback
                        raise
                    attempts_left -= 1
                    sleep(retry_delay)

        # the following function is idempotent
        self.rmq_publisher.create_queue(queue_name=self.processor_name)