"""
# METS server functionality
"""
import re
from os import environ, _exit, chmod
from io import BytesIO
from typing import Any, Dict, Optional, Union, List, Tuple
from pathlib import Path
from urllib.parse import urlparse
import socket
from fastapi import FastAPI, Request, File, Form, Response
from fastapi.responses import JSONResponse
from requests import request, Session as requests_session
from requests.exceptions import ConnectionError
from requests_unixsocket import Session as requests_unixsocket_session
from pydantic import BaseModel, Field, ValidationError
import uvicorn
from ocrd_models import OcrdMets, OcrdFile, ClientSideOcrdFile, OcrdAgent, ClientSideOcrdAgent
from ocrd_utils import initLogging, getLogger, deprecated_alias
#
# Models
#
class OcrdFileModel(BaseModel):
    """
    Serializable description of a single file entry, as exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    file_grp : str = Field()
    file_id : str = Field()
    mimetype : str = Field()
    page_id : Optional[str] = Field()
    url : Optional[str] = Field()
    local_filename : Optional[str] = Field()

    @staticmethod
    def create(file_grp : str, file_id : str, page_id : Optional[str], url : Optional[str], local_filename : Optional[Union[str, Path]], mimetype : str):
        """
        Build an :py:class:`OcrdFileModel` from plain values.

        :param local_filename: a ``str`` or :py:class:`pathlib.Path` (converted
            to ``str``), or ``None``, which is kept as ``None``
        :returns: the validated model
        """
        return OcrdFileModel(
            file_grp=file_grp,
            file_id=file_id,
            page_id=page_id,
            mimetype=mimetype,
            url=url,
            # str(None) would produce the bogus filename 'None'
            local_filename=str(local_filename) if local_filename is not None else None,
        )
class OcrdAgentModel(BaseModel):
    """
    Serializable description of a METS agent, as exchanged between
    :py:class:`ClientSideOcrdMets` and :py:class:`OcrdMetsServer`.
    """
    name : str = Field()
    type : str = Field()
    role : str = Field()
    otherrole : Optional[str] = Field()
    # Optional like ``otherrole``: agents listed by the server come straight
    # from the METS, where presumably only agents of TYPE "OTHER" carry an
    # OTHERTYPE; a required ``str`` would reject all the others.
    othertype : Optional[str] = Field()
    notes : Optional[List[Tuple[Dict[str, str], Optional[str]]]] = Field()

    @staticmethod
    def create(name : str, _type : str, role : str, otherrole : Optional[str] = None, othertype : Optional[str] = None, notes : Optional[List[Tuple[Dict[str, str], Optional[str]]]] = None):
        """
        Build an :py:class:`OcrdAgentModel` from plain values.

        :param _type: stored as the ``type`` field (underscored so the
            parameter does not shadow the ``type`` builtin)
        :param otherrole: optional, defaults to ``None``
        :param othertype: optional, defaults to ``None``
        :param notes: optional, defaults to ``None``
        :returns: the validated model
        """
        return OcrdAgentModel(name=name, type=_type, role=role, otherrole=otherrole, othertype=othertype, notes=notes)
class OcrdFileListModel(BaseModel):
    """
    Serializable list of :py:class:`OcrdFileModel` entries, used as the
    response model of a file search.
    """
    files : List[OcrdFileModel] = Field()

    @staticmethod
    def create(files : List[OcrdFile]):
        """
        Convert OcrdFile objects into an :py:class:`OcrdFileListModel`.

        :param files: iterable of objects exposing ``fileGrp``, ``ID``,
            ``mimetype``, ``pageId``, ``url`` and ``local_filename``
        """
        file_models = []
        for ocrd_file in files:
            file_models.append(OcrdFileModel.create(
                file_grp=ocrd_file.fileGrp,
                file_id=ocrd_file.ID,
                mimetype=ocrd_file.mimetype,
                page_id=ocrd_file.pageId,
                url=ocrd_file.url,
                local_filename=ocrd_file.local_filename,
            ))
        return OcrdFileListModel(files=file_models)
class OcrdFileGroupListModel(BaseModel):
    """
    Serializable list of file group names.
    """
    file_groups : List[str] = Field()

    @staticmethod
    def create(file_groups : List[str]):
        """
        Wrap a list of file group names in an :py:class:`OcrdFileGroupListModel`.
        """
        model = OcrdFileGroupListModel(file_groups=file_groups)
        return model
class OcrdAgentListModel(BaseModel):
    """
    Serializable list of :py:class:`OcrdAgentModel` entries.
    """
    agents : List[OcrdAgentModel] = Field()

    @staticmethod
    def create(agents : List[OcrdAgent]):
        """
        Convert OcrdAgent objects into an :py:class:`OcrdAgentListModel`.

        :param agents: iterable of objects exposing ``name``, ``type``,
            ``role``, ``otherrole``, ``othertype`` and ``notes``
        """
        def _to_model(agent):
            # ``type`` attribute maps onto the ``_type`` parameter
            return OcrdAgentModel.create(
                name=agent.name,
                _type=agent.type,
                role=agent.role,
                otherrole=agent.otherrole,
                othertype=agent.othertype,
                notes=agent.notes,
            )
        return OcrdAgentListModel(agents=list(map(_to_model, agents)))
#
# Client
#
class ClientSideOcrdMets():
    """
    Partial substitute for :py:class:`ocrd_models.ocrd_mets.OcrdMets` which provides for
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.find_all_files`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_agent`,
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.agents`, and
    :py:meth:`ocrd_models.ocrd_mets.OcrdMets.add_file` to query via HTTP a
    :py:class:`ocrd.mets_server.OcrdMetsServer`.

    Accessing any attribute not defined on this class raises
    :py:exc:`NotImplementedError`.
    """

    def __init__(self, url):
        """
        :param url: ``http://`` or ``https://`` URL of a TCP METS server, or
            the filesystem path of the server's Unix domain socket
        """
        # Same rule as OcrdMetsServer.__init__: both http:// and https:// mean TCP,
        # anything else is treated as a Unix domain socket path
        is_tcp = url.startswith(('http://', 'https://'))
        self.log = getLogger(f'ocrd.mets_client[{url}]')
        # For UDS, the http+unix:// scheme carries the socket path
        # percent-encoded in the host part of the URL
        self.url = url if is_tcp else f'http+unix://{url.replace("/", "%2F")}'
        self.session = requests_session() if is_tcp else requests_unixsocket_session()

    def __getattr__(self, name):
        # Only reached for attributes not defined above: make unsupported
        # OcrdMets API explicit instead of failing obscurely later
        raise NotImplementedError(f"ClientSideOcrdMets has no access to '{name}' - try without METS server")

    def __str__(self):
        return f'<ClientSideOcrdMets[url={self.url}]>'

    @property
    def workspace_path(self):
        """Workspace directory as reported by ``GET /workspace_path``."""
        return self.session.request('GET', f'{self.url}/workspace_path').text

    def reload(self):
        """Ask the server to reload the METS (``POST /reload``); returns the response text."""
        return self.session.request('POST', f'{self.url}/reload').text

    @deprecated_alias(ID="file_id")
    @deprecated_alias(pageId="page_id")
    @deprecated_alias(fileGrp="file_grp")
    def find_files(self, **kwargs):
        """
        Query ``GET /file`` and yield the matching files.

        Keyword filters are sent as query parameters; the legacy camelCase
        names ``pageId``, ``ID`` and ``fileGrp`` are mapped to ``page_id``,
        ``file_id`` and ``file_grp``. This is a generator, so the request is
        only sent once iteration starts.

        :yields: :py:class:`ClientSideOcrdFile` instances
        """
        self.log.debug('find_files(%s)', kwargs)
        if 'pageId' in kwargs:
            kwargs['page_id'] = kwargs.pop('pageId')
        if 'ID' in kwargs:
            kwargs['file_id'] = kwargs.pop('ID')
        if 'fileGrp' in kwargs:
            kwargs['file_grp'] = kwargs.pop('fileGrp')
        r = self.session.request('GET', f'{self.url}/file', params={**kwargs})
        for f in r.json()['files']:
            yield ClientSideOcrdFile(None, ID=f['file_id'], pageId=f['page_id'], fileGrp=f['file_grp'], url=f['url'], local_filename=f['local_filename'], mimetype=f['mimetype'])

    def find_all_files(self, *args, **kwargs):
        """Like :py:meth:`find_files`, but return a list."""
        return list(self.find_files(*args, **kwargs))

    def add_agent(self, *args, **kwargs):
        """
        Send an agent to the server via ``POST /agent``.

        Only keyword arguments (as accepted by ``OcrdAgentModel.create``) are
        transmitted; positional arguments are ignored.

        :returns: the raw HTTP response
        """
        return self.session.request('POST', f'{self.url}/agent', json=OcrdAgentModel.create(**kwargs).dict())

    @property
    def agents(self):
        """Agents fetched via ``GET /agent`` as :py:class:`ClientSideOcrdAgent` objects."""
        agent_dicts = self.session.request('GET', f'{self.url}/agent').json()['agents']
        for agent_dict in agent_dicts:
            # the model field is ``type``, the agent constructor expects ``_type``
            agent_dict['_type'] = agent_dict.pop('type')
        return [ClientSideOcrdAgent(None, **agent_dict) for agent_dict in agent_dicts]

    @property
    def unique_identifier(self):
        """Unique identifier as reported by ``GET /unique_identifier``."""
        return self.session.request('GET', f'{self.url}/unique_identifier').text

    @property
    def file_groups(self):
        """File group names as reported by ``GET /file_groups``."""
        return self.session.request('GET', f'{self.url}/file_groups').json()['file_groups']

    @deprecated_alias(pageId="page_id")
    @deprecated_alias(ID="file_id")
    def add_file(self, file_grp, content=None, file_id=None, url=None, local_filename=None, mimetype=None, page_id=None, **kwargs):
        """
        Register a file with the server via ``POST /file``.

        ``content`` and any further ``kwargs`` are accepted for signature
        compatibility but are not transmitted. A failed request is logged
        as a warning rather than raised.

        :returns: a :py:class:`ClientSideOcrdFile` describing the requested file
        """
        data = OcrdFileModel.create(
            file_id=file_id,
            file_grp=file_grp,
            page_id=page_id,
            mimetype=mimetype,
            url=url,
            local_filename=local_filename)
        r = self.session.request('POST', f'{self.url}/file', data=data.dict())
        if not r.ok:
            # Not raised, so best-effort callers keep working, but no longer
            # silent (e.g. the server reports a duplicate file ID as HTTP 400)
            self.log.warning("add_file(file_grp=%s, file_id=%s) failed with HTTP %s: %s",
                             file_grp, file_id, r.status_code, r.text)
        return ClientSideOcrdFile(
            None,
            ID=file_id,
            fileGrp=file_grp,
            url=url,
            pageId=page_id,
            mimetype=mimetype,
            local_filename=local_filename)

    def save(self):
        """Ask the server to write the METS to disk (``PUT /``)."""
        self.session.request('PUT', self.url)

    def stop(self):
        """Ask the server to save and terminate (``DELETE /``)."""
        try:
            self.session.request('DELETE', self.url)
        except ConnectionError:
            # Expected because we exit the process without returning
            pass
#
# Server
#
class OcrdMetsServer():
    """
    FastAPI/uvicorn server exposing a workspace's METS over HTTP, either on a
    TCP ``host:port`` or on a Unix domain socket, for use by
    :py:class:`ClientSideOcrdMets`.
    """

    def __init__(self, workspace, url):
        """
        :param workspace: the workspace whose ``mets`` is served
        :param url: ``http://`` or ``https://`` URL to listen on (TCP), or a
            filesystem path for a Unix domain socket
        """
        self.workspace = workspace
        self.url = url
        self.is_uds = not (url.startswith('http://') or url.startswith('https://'))
        self.log = getLogger(f'ocrd.mets_server[{self.url}]')

    def shutdown(self):
        """
        Remove the Unix socket file (if serving on one) and terminate the
        process immediately.
        """
        self.log.info("Shutting down METS server")
        if self.is_uds:
            try:
                Path(self.url).unlink()
            except OSError as err:
                # A missing or undeletable socket file must not keep the server alive
                self.log.warning("Could not remove socket file %s: %s", self.url, err)
        # os._exit because uvicorn catches SystemExit raised by sys.exit
        _exit(0)

    def startup(self):
        """
        Build the FastAPI application for ``self.workspace`` and run it with
        uvicorn. ``uvicorn.run`` blocks; the server is stopped via ``DELETE /``,
        which saves the METS and calls :py:meth:`shutdown`.
        """
        self.log.info("Starting up METS server")
        workspace = self.workspace

        app = FastAPI(
            title="OCR-D METS Server",
            description="Providing simultaneous write-access to mets.xml for OCR-D",
        )

        @app.exception_handler(ValidationError)
        async def exception_handler_validation_error(request: Request, exc: ValidationError):
            return JSONResponse(status_code=400, content=exc.errors())

        @app.exception_handler(FileExistsError)
        async def exception_handler_file_exists(request: Request, exc: FileExistsError):
            return JSONResponse(status_code=400, content=str(exc))

        @app.exception_handler(re.error)
        async def exception_handler_invalid_regex(request: Request, exc: re.error):
            return JSONResponse(status_code=400, content=f'invalid regex: {exc}')

        @app.get("/file", response_model=OcrdFileListModel)
        async def find_files(
            file_grp : Optional[str] = None,
            file_id : Optional[str] = None,
            page_id : Optional[str] = None,
            mimetype : Optional[str] = None,
            local_filename : Optional[str] = None,
            url : Optional[str] = None,
        ):
            """
            Find files in the mets
            """
            found = workspace.mets.find_all_files(fileGrp=file_grp, ID=file_id, pageId=page_id, mimetype=mimetype, local_filename=local_filename, url=url)
            return OcrdFileListModel.create(found)

        # ``async`` like every other handler: FastAPI runs plain ``def``
        # endpoints in a thread pool, which would let saving overlap with
        # modifications made concurrently by the event-loop handlers.
        @app.put('/')
        async def save():
            return workspace.save_mets()

        @app.post('/file', response_model=OcrdFileModel)
        async def add_file(
            file_grp : str = Form(),
            file_id : str = Form(),
            # Form(None), not Form(): a bare Form() makes the field required,
            # but clients omit page_id entirely when it is None
            page_id : Optional[str] = Form(None),
            mimetype : str = Form(),
            url : Optional[str] = Form(None),
            local_filename : Optional[str] = Form(None),
        ):
            """
            Add a file
            """
            # Validate
            file_resource = OcrdFileModel.create(file_grp=file_grp, file_id=file_id, page_id=page_id, mimetype=mimetype, url=url, local_filename=local_filename)
            # Add to workspace
            kwargs = file_resource.dict()
            workspace.add_file(**kwargs)
            return file_resource

        @app.get('/file_groups', response_model=OcrdFileGroupListModel)
        async def file_groups():
            return {'file_groups': workspace.mets.file_groups}

        @app.post('/agent', response_model=OcrdAgentModel)
        async def add_agent(agent : OcrdAgentModel):
            kwargs = agent.dict()
            # the model field is ``type``, OcrdMets.add_agent expects ``_type``
            kwargs['_type'] = kwargs.pop('type')
            workspace.mets.add_agent(**kwargs)
            return agent

        @app.get('/agent', response_model=OcrdAgentListModel)
        async def agents():
            return OcrdAgentListModel.create(workspace.mets.agents)

        @app.get('/unique_identifier', response_model=str)
        async def unique_identifier():
            return Response(content=workspace.mets.unique_identifier, media_type='text/plain')

        @app.get('/workspace_path', response_model=str)
        async def workspace_path():
            return Response(content=workspace.directory, media_type="text/plain")

        @app.post('/reload')
        async def workspace_reload_mets():
            workspace.reload_mets()
            return Response(content=f'Reloaded from {workspace.directory}', media_type="text/plain")

        @app.delete('/')
        async def stop():
            """
            Stop the server
            """
            getLogger('ocrd.models.ocrd_mets').info('Shutting down')
            workspace.save_mets()
            self.shutdown()

        # ------------- #

        if self.is_uds:
            # Create the socket file up front and make it world-readable and
            # -writable to avoid permission errors
            self.log.debug(f"chmod 0o666 {self.url}")
            # context manager closes the socket even if bind() fails
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.bind(self.url)  # creates the socket file
            chmod(self.url, 0o666)
            uvicorn_kwargs = {'uds': self.url}
        else:
            parsed = urlparse(self.url)
            uvicorn_kwargs = {'host': parsed.hostname, 'port': parsed.port}

        self.log.debug("Starting uvicorn")
        uvicorn.run(app, **uvicorn_kwargs)