diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eb2fc983..e609452d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,9 +52,6 @@ jobs: python -m build --sdist --wheel --outdir dist/ . - name: Publish a Python distribution to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} - name: Upload PyPI artifacts to GH storage uses: actions/upload-artifact@v3 with: diff --git a/environment.yml b/environment.yml index 6d192901..6e4ca53d 100644 --- a/environment.yml +++ b/environment.yml @@ -32,4 +32,6 @@ dependencies: - setuptools>=36.6.0,<70.0.0 - pip: - coremltools~=8.1 + - htrmopo + - platformdirs - file:. diff --git a/environment_cuda.yml b/environment_cuda.yml index d9525927..d001181c 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -33,4 +33,6 @@ dependencies: - setuptools>=36.6.0,<70.0.0 - pip: - coremltools~=8.1 + - htrmopo + - platformdirs - file:. diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index fe67bf80..0eb52936 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -18,94 +18,197 @@ Command line driver for publishing models to the model repository. """ +import re import logging -import os import click +from pathlib import Path +from difflib import get_close_matches + from .util import message logging.captureWarnings(True) logger = logging.getLogger('kraken') +def _validate_script(script: str) -> str: + from htrmopo.util import _iso15924 + if script not in _iso15924: + return get_close_matches(script, _iso15924.keys()) + return script + + +def _validate_language(language: str) -> str: + from htrmopo.util import _iso639_3 + if language not in _iso639_3: + return get_close_matches(language, _iso639_3.keys()) + return language + + +def _validate_license(license: str) -> str: + from htrmopo.util import _licenses + if license not in _licenses: + return get_close_matches(license, _licenses.keys()) + return license + + +def _get_field_list(name, + validation_fn=lambda x: x, + required: bool = False): + values = [] + while True: + value = click.prompt(name, default='') + if value: + if (cand := validation_fn(value)) == value: + values.append(value) + else: + message(f'Not a valid {name} value. Did you mean {cand}?') + else: + if click.confirm(f'All `{name}` values added?'): + if required and not values: + message(f'`{name}` is a required field.') + continue + else: + break + else: + continue + return values + + @click.command('publish') @click.pass_context @click.option('-i', '--metadata', show_default=True, - type=click.File(mode='r', lazy=True), help='Metadata for the ' - 'model. Will be prompted from the user if not given') + type=click.File(mode='r', lazy=True), help='Model card file for the model.') @click.option('-a', '--access-token', prompt=True, help='Zenodo access token') +@click.option('-d', '--doi', help='DOI of an existing record to update') @click.option('-p', '--private/--public', default=False, help='Disables Zenodo ' 'community inclusion request. Allows upload of models that will not show ' 'up on `kraken list` output') @click.argument('model', nargs=1, type=click.Path(exists=False, readable=True, dir_okay=False)) -def publish(ctx, metadata, access_token, private, model): +def publish(ctx, metadata, access_token, doi, private, model): """ Publishes a model on the zenodo model repository. """ import json + import yaml + import tempfile - from importlib import resources - from jsonschema import validate - from jsonschema.exceptions import ValidationError + from htrmopo import publish_model, update_model - from kraken import repo - from kraken.lib import models + from kraken.lib.vgsl import TorchVGSLModel from kraken.lib.progress import KrakenDownloadProgressBar - ref = resources.files('kraken').joinpath('metadata.schema.json') - with open(ref, 'rb') as fp: - schema = json.load(fp) - - nn = models.load_any(model) - - if not metadata: - author = click.prompt('author') - affiliation = click.prompt('affiliation') - summary = click.prompt('summary') - description = click.edit('Write long form description (training data, transcription standards) of the model here') - accuracy_default = None - # take last accuracy measurement in model metadata - if 'accuracy' in nn.nn.user_metadata and nn.nn.user_metadata['accuracy']: - accuracy_default = nn.nn.user_metadata['accuracy'][-1][1] * 100 - accuracy = click.prompt('accuracy on test set', type=float, default=accuracy_default) - script = [ - click.prompt( - 'script', - type=click.Choice( - sorted( - schema['properties']['script']['items']['enum'])), - show_choices=True)] - license = click.prompt( - 'license', - type=click.Choice( - sorted( - schema['properties']['license']['enum'])), - show_choices=True) - metadata = { - 'authors': [{'name': author, 'affiliation': affiliation}], - 'summary': summary, - 'description': description, - 'accuracy': accuracy, - 'license': license, - 'script': script, - 'name': os.path.basename(model), - 'graphemes': ['a'] - } - while True: - try: - validate(metadata, schema) - except ValidationError as e: - message(e.message) - metadata[e.path[-1]] = click.prompt(e.path[-1], type=float if e.schema['type'] == 'number' else str) - continue - break + pub_fn = publish_model + + _yaml_delim = r'(?:---|\+\+\+)' + _yaml = r'(.*?)' + _content = r'\s*(.+)$' + _re_pattern = r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content + _yaml_regex = re.compile(_re_pattern, re.S | re.M) + + nn = TorchVGSLModel.load_model(model) + frontmatter = {} + # construct metadata if none is given + if metadata: + frontmatter, content = _yaml_regex.match(metadata.read()).groups() + frontmatter = yaml.safe_load(frontmatter) else: - metadata = json.load(metadata) - validate(metadata, schema) - metadata['graphemes'] = [char for char in ''.join(nn.codec.c2l.keys())] - with KrakenDownloadProgressBar() as progress: + frontmatter['summary'] = click.prompt('summary') + content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here') + + creators = [] + message('To stop adding authors, leave the author name field empty.') + while True: + author = click.prompt('author name', default='') + if author: + creators.append({'name': author}) + else: + if click.confirm('All authors added?'): + break + else: + continue + affiliation = click.prompt('affiliation', default='') + orcid = click.prompt('orcid', default='') + if affiliation is not None: + creators[-1]['affiliation'] = affiliation + if orcid is not None: + creators[-1]['orcid'] = orcid + if not creators: + raise click.UsageError('The `authors` field is obligatory. Aborting') + + frontmatter['authors'] = creators + while True: + license = click.prompt('license') + if (lic := _validate_license(license)) == license: + frontmatter['license'] = license + break + else: + message(f'Not a valid license identifer. Did you mean {lic}?') + + message('To stop adding values to the following fields, enter an empty field.') + + frontmatter['language'] = _get_field_list('language', _validate_language, required=True) + frontmatter['script'] = _get_field_list('script', _validate_script, required=True) + + if len(tags := _get_field_list('tag')): + frontmatter['tags'] = tags + ['kraken_pytorch'] + if len(datasets := _get_field_list('dataset URL')): + frontmatter['datasets'] = datasets + if len(base_model := _get_field_list('base model URL')): + frontmatter['base_model'] = base_model + + software_hints = ['kind=vgsl'] + + # take last metrics field, falling back to accuracy field in model metadata + if nn.model_type == 'recognition': + metrics = {} + if len(nn.user_metadata.get('metrics', '')): + if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None: + metrics['cer'] = 100 - (val_accuracy * 100) + if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None: + metrics['wer'] = 100 - (val_word_accuracy * 100) + elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None: + metrics['cer'] = 100 - accuracy + frontmatter['metrics'] = metrics + + # some recognition-specific software hints and metrics + software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) + + frontmatter['software_hints'] = software_hints + + frontmatter['software_name'] = 'kraken' + frontmatter['model_type'] = [nn.model_type] + + # build temporary directory + with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress: upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False) - oid = repo.publish_model(model, metadata, access_token, lambda total, advance: progress.update(upload_task, total=total, advance=advance), private) - message('model PID: {}'.format(oid)) + + model = Path(model).resolve() + tmpdir = Path(tmpdir) + (tmpdir / model.name).resolve().symlink_to(model) + if nn.model_type == 'recognition': + # v0 metadata only supports recognition models + v0_metadata = { + 'summary': frontmatter['summary'], + 'description': content, + 'license': frontmatter['license'], + 'script': frontmatter['script'], + 'name': model.name, + 'graphemes': [char for char in ''.join(nn.codec.c2l.keys())] + } + if frontmatter['metrics']: + v0_metadata['accuracy'] = 100 - metrics['cer'] + with open(tmpdir / 'metadata.json', 'w') as fo: + json.dump(v0_metadata, fo) + kwargs = {'model': tmpdir, + 'model_card': f'---\n{yaml.dump(frontmatter)}---\n{content}', + 'access_token': access_token, + 'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance), + 'private': private} + if doi: + pub_fn = update_model + kwargs['model_id'] = doi + oid = pub_fn(**kwargs) + message(f'model PID: {oid}') diff --git a/kraken/kraken.py b/kraken/kraken.py index 23a12daf..d1db25ea 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -18,20 +18,28 @@ Command line drivers for recognition functionality. """ -import dataclasses -import logging import os import uuid +import click import shlex +import logging import warnings -from functools import partial -from pathlib import Path -from typing import IO, Any, Callable, Dict, List, Union, cast +import dataclasses -import click from PIL import Image +from pathlib import Path +from itertools import chain +from functools import partial from importlib import resources +from platformdirs import user_data_dir +from typing import IO, Any, Callable, Dict, List, Union, cast + +from rich import print +from rich.tree import Tree +from rich.table import Table +from rich.console import Group from rich.traceback import install +from rich.markdown import Markdown from kraken.lib import log @@ -99,7 +107,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou processing_steps=ctx.meta['steps'])) else: form = None - ext = os.path.splitext(output)[1] + ext = Path(output).suffix if ext in ['.jpg', '.jpeg', '.JPG', '.JPEG', '']: form = 'png' if ext: @@ -351,7 +359,6 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty placing their respective outputs in temporary files. """ import glob - import os.path import tempfile from threadpoolctl import threadpool_limits @@ -365,9 +372,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty # expand batch inputs if batch_input and suffix: for batch_expr in batch_input: - for in_file in glob.glob(os.path.expanduser(batch_expr), recursive=True): - - input.append((in_file, '{}{}'.format(os.path.splitext(in_file)[0], suffix))) + for in_file in glob.glob(str(Path(batch_expr).expanduser()), recursive=True): + input.append(Path(in_file).with_suffix(suffix)) # parse pdfs if format_type == 'pdf': @@ -507,13 +513,14 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, logger.warning(f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.') boxes = False + model = [Path(m) for m in model] if boxes is False: if not model: model = [SEGMENTATION_DEFAULT_MODEL] ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()), category='processing', description='Baseline and region segmentation', - settings={'model': [os.path.basename(m) for m in model], + settings={'model': [m.name for m in model], 'text_direction': text_direction})) # first try to find the segmentation models by their given names, then @@ -521,15 +528,16 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, locations = [] for m in model: location = None - search = [m, os.path.join(click.get_app_dir(APP_NAME), m)] + search = chain([m], + Path(user_data_dir('htrmopo')).rglob(str(m)), + Path(click.get_app_dir('kraken')).rglob(str(m))) for loc in search: - if os.path.isfile(loc): + if loc.is_file(): location = loc locations.append(loc) break if not location: - raise click.BadParameter(f'No model for {m} found') - + raise click.BadParameter(f'No model for {str(m)} found') from kraken.lib.vgsl import TorchVGSLModel model = [] @@ -630,11 +638,12 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): nm: Dict[str, models.TorchSeqRecognizer] = {} ign_tags = model.pop('ignore') for k, v in model.items(): - search = [v, - os.path.join(click.get_app_dir(APP_NAME), v)] + search = chain([Path(v)], + Path(user_data_dir('htrmopo')).rglob(v), + Path(click.get_app_dir('kraken')).rglob(v)) location = None for loc in search: - if os.path.isfile(loc): + if loc.is_file(): location = loc break if not location: @@ -661,7 +670,7 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): category='processing', description='Text line recognition', settings={'text_direction': text_direction, - 'models': ' '.join(os.path.basename(v) for v in model.values()), + 'models': ' '.join(Path(v).name for v in model.values()), 'pad': pad, 'bidi_reordering': reorder})) @@ -677,29 +686,90 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): @cli.command('show') @click.pass_context +@click.option('-V', '--metadata-version', + default='highest', + type=click.Choice(['v0', 'v1', 'highest']), + help='Version of metadata to fetch if multiple exist in repository.') @click.argument('model_id') -def show(ctx, model_id): +def show(ctx, metadata_version, model_id): """ Retrieves model metadata from the repository. """ - from kraken import repo + from htrmopo.util import iso15924_to_name, iso639_3_to_name + from kraken.repo import get_description from kraken.lib.util import is_printable, make_printable - desc = repo.get_description(model_id) + def _render_creators(creators): + o = [] + for creator in creators: + c_text = creator['name'] + if (orcid := creator.get('orcid', None)) is not None: + c_text += f' ({orcid})' + if (affiliation := creator.get('affiliation', None)) is not None: + c_text += f' ({affiliation})' + o.append(c_text) + return o + + def _render_metrics(metrics): + if metrics: + return [f'{k}: {v:.2f}' for k, v in metrics.items()] + return '' + + if metadata_version == 'highest': + metadata_version = None - chars = [] - combining = [] - for char in sorted(desc['graphemes']): - if not is_printable(char): - combining.append(make_printable(char)) - else: - chars.append(char) - message( - 'name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\naccuracy: {:.2f}%\nlicense: {}\nauthor(s): {}\ndate: {}'.format( - model_id, desc['summary'], desc['description'], ' '.join( - desc['script']), ''.join(chars), ', '.join(combining), desc['accuracy'], desc['license']['id'], '; '.join( - x['name'] for x in desc['creators']), desc['publication_date'])) - ctx.exit(0) + try: + desc = get_description(model_id, + version=metadata_version, + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) + except ValueError as e: + logger.error(e) + ctx.exit(1) + + if desc.version == 'v0': + chars = [] + combining = [] + for char in sorted(desc.graphemes): + if not is_printable(char): + combining.append(make_printable(char)) + else: + chars.append(char) + + table = Table(title=desc.summary, show_header=False) + table.add_column('key', justify="left", no_wrap=True) + table.add_column('value', justify="left", no_wrap=False) + table.add_row('DOI', desc.doi) + table.add_row('concept DOI', desc.concept_doi) + table.add_row('publication date', desc.publication_date.isoformat()) + table.add_row('model type', Group(*desc.model_type)) + table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script])) + table.add_row('alphabet', Group(' '.join(chars), ', '.join(combining))) + table.add_row('keywords', Group(*desc.keywords)) + table.add_row('metrics', Group(*_render_metrics(desc.metrics))) + table.add_row('license', desc.license) + table.add_row('creators', Group(*_render_creators(desc.creators))) + table.add_row('description', desc.description) + elif desc.version == 'v1': + table = Table(title=desc.summary, show_header=False) + table.add_column('key', justify="left", no_wrap=True) + table.add_column('value', justify="left", no_wrap=False) + table.add_row('DOI', desc.doi) + table.add_row('concept DOI', desc.concept_doi) + table.add_row('publication date', desc.publication_date.isoformat()) + table.add_row('model type', Group(*desc.model_type)) + table.add_row('language', Group(*[iso639_3_to_name(x) for x in desc.language])) + table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script])) + table.add_row('keywords', Group(*desc.keywords) if desc.keywords else '') + table.add_row('datasets', Group(*desc.datasets) if desc.datasets else '') + table.add_row('metrics', Group(*_render_metrics(desc.metrics)) if desc.metrics else '') + table.add_row('base model', Group(*desc.base_model) if desc.base_model else '') + table.add_row('software', desc.software_name) + table.add_row('software_hints', Group(*desc.software_hints) if desc.software_hints else '') + table.add_row('license', desc.license) + table.add_row('creators', Group(*_render_creators(desc.creators))) + table.add_row('description', Markdown(desc.description)) + + print(table) @cli.command('list') @@ -708,15 +778,29 @@ def list_models(ctx): """ Lists models in the repository. """ - from kraken import repo + from kraken.repo import get_listing from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) - model_list = repo.get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance)) - for id, metadata in model_list.items(): - message('{} ({}) - {}'.format(id, ', '.join(metadata['type']), metadata['summary'])) - ctx.exit(0) + repository = get_listing(callback=lambda total, advance: progress.update(download_task, total=total, advance=advance), + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) + + table = Table(show_header=True) + table.add_column('DOI', justify="left", no_wrap=True) + table.add_column('summary', justify="left", no_wrap=False) + table.add_column('model type', justify="left", no_wrap=False) + table.add_column('keywords', justify="left", no_wrap=False) + + for k, records in repository.items(): + t = Tree(k) + [t.add(x.doi) for x in records] + table.add_row(t, + Group(*[''] + [x.summary for x in records]), + Group(*[''] + ['; '.join(x.model_type) for x in records]), + Group(*[''] + ['; '.join(x.keywords) for x in records])) + + print(table) @cli.command('get') @@ -726,20 +810,24 @@ def get(ctx, model_id): """ Retrieves a model from the repository. """ - from kraken import repo + from htrmopo import get_model + + from kraken.repo import get_description from kraken.lib.progress import KrakenDownloadProgressBar try: - os.makedirs(click.get_app_dir(APP_NAME)) - except OSError: - pass + get_description(model_id, + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) + except ValueError as e: + logger.error(e) + ctx.exit(1) with KrakenDownloadProgressBar() as progress: download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False) - filename = repo.get_model(model_id, click.get_app_dir(APP_NAME), - lambda total, advance: progress.update(download_task, total=total, advance=advance)) - message(f'Model name: {filename}') - ctx.exit(0) + model_dir = get_model(model_id, + callback=lambda total, advance: progress.update(download_task, total=total, advance=advance)) + model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iterdir())) + message(f'Model dir: {model_dir} (model files: {", ".join(x.name for x in model_candidates)})') if __name__ == '__main__': diff --git a/kraken/metadata.schema.json b/kraken/metadata.schema.json deleted file mode 100644 index a7483f69..00000000 --- a/kraken/metadata.schema.json +++ /dev/null @@ -1,103 +0,0 @@ - { - "definitions": {}, - "$schema": "http://json-schema.org/draft-07/schema#", - "$id": "http://example.com/root.json", - "type": "object", - "title": "The Root Schema", - "required": [ - "authors", - "summary", - "description", - "accuracy", - "license", - "script", - "name", - "graphemes" - ], - "properties": { - "authors": { - "$id": "#/properties/authors", - "type": "array", - "title": "Authors of the model", - "items": { - "$id": "#/properties/authors/items", - "type": "object", - "title": "items", - "required": [ - "name", - "affiliation" - ], - "properties": { - "name": { - "$id": "#/properties/authors/items/properties/name", - "type": "string", - "title": "A single author's name", - "pattern": "^(.*)$" - }, - "affiliation": { - "$id": "#/properties/authors/items/properties/affiliation", - "type": "string", - "title": "A single author's institutional affiliation", - "pattern": "^(.*)$" - } - } - } - }, - "summary": { - "$id": "#/properties/summary", - "type": "string", - "title": "A one-line summary of the model", - "pattern": "^(.*)$" - }, - "description": { - "$id": "#/properties/description", - "type": "string", - "title": "A long-form description of the model." - }, - "accuracy": { - "$id": "#/properties/accuracy", - "type": "number", - "title": "Test accuracy of the model", - "default": 0.0, - "minimum": 0.0, - "maximum": 100.0 - }, - "license": { - "$id": "#/properties/license", - "type": "string", - "title": "License of the model", - "default": "Apache-2.0", - "enum": ["AAL", "AFL-3.0", "AGPL-3.0", "APL-1.0", "APSL-2.0", "Against-DRM", "Apache-1.1", "Apache-2.0", "Artistic-2.0", "BSD-2-Clause", "BSD-3-Clause", "BSL-1.0", "BitTorrent-1.1", "CATOSL-1.1", "CC-BY-4.0", "CC-BY-NC-4.0", "CC-BY-SA-4.0", "CC0-1.0", "CDDL-1.0", "CECILL-2.1", "CNRI-Python", "CPAL-1.0", "CUA-OPL-1.0", "DSL", "ECL-2.0", "EFL-2.0", "EPL-1.0", "EPL-2.0", "EUDatagrid", "EUPL-1.1", "Entessa", "FAL-1.3", "Fair", "Frameworx-1.0", "GFDL-1.3-no-cover-texts-no-invariant-sections", "GPL-2.0", "GPL-3.0", "HPND", "IPA", "IPL-1.0", "ISC", "Intel", "LGPL-2.1", "LGPL-3.0", "LO-FR-2.0", "LPL-1.0", "LPL-1.02", "LPPL-1.3c", "MIT", "MPL-1.0", "MPL-1.1", "MPL-2.0", "MS-PL", "MS-RL", "MirOS", "Motosoto", "Multics", "NASA-1.3", "NCSA", "NGPL", "NPOSL-3.0", "NTP", "Naumen", "Nokia", "OCLC-2.0", "ODC-BY-1.0", "ODbL-1.0", "OFL-1.1", "OGL-Canada-2.0", "OGL-UK-1.0", "OGL-UK-2.0", "OGL-UK-3.0", "OGTSL", "OSL-3.0", "PDDL-1.0", "PHP-3.0", "PostgreSQL", "Python-2.0", "QPL-1.0", "RPL-1.5", "RPSL-1.0", "RSCPL", "SISSL", "SPL-1.0", "SimPL-2.0", "Sleepycat", "Talis", "Unlicense", "VSL-1.0", "W3C", "WXwindows", "Watcom-1.0", "Xnet", "ZPL-2.0", "Zlib", "dli-model-use", "geogratis", "hesa-withrights", "localauth-withrights", "met-office-cp", "mitre", "notspecified", "other-at", "other-closed", "other-nc", "other-open", "other-pd", "ukclickusepsi", "ukcrown", "ukcrown-withrights", "ukpsi"] - }, - "script": { - "$id": "#/properties/script", - "type": "array", - "uniqueItems": true, - "minItems": 1, - "title": "ISO 15924 scripts recognized by the model", - "items": { - "$id": "#/properties/script/items", - "type": "string", - "enum": ["Tang", "Xsux", "Xpeo", "Blis", "Ugar", "Egyp", "Brai", "Egyh", "Loma", "Egyd", "Hluw", "Maya", "Sgnw", "Inds", "Mero", "Merc", "Sarb", "Narb", "Roro", "Phnx", "Lydi", "Tfng", "Samr", "Armi", "Hebr", "Palm", "Hatr", "Prti", "Phli", "Phlp", "Phlv", "Avst", "Syrc", "Syrn", "Syrj", "Syre", "Mani", "Mand", "Mong", "Nbat", "Arab", "Aran", "Nkoo", "Adlm", "Thaa", "Orkh", "Hung", "Grek", "Cari", "Lyci", "Copt", "Goth", "Ital", "Runr", "Ogam", "Latn", "Latg", "Latf", "Moon", "Osge", "Cyrl", "Cyrs", "Glag", "Elba", "Perm", "Armn", "Aghb", "Geor", "Geok", "Dupl", "Dsrt", "Bass", "Osma", "Olck", "Wara", "Pauc", "Mroo", "Medf", "Visp", "Shaw", "Plrd", "Jamo", "Bopo", "Hang", "Kore", "Kits", "Teng", "Cirt", "Sara", "Piqd", "Brah", "Sidd", "Khar", "Guru", "Gong", "Gonm", "Mahj", "Deva", "Sylo", "Kthi", "Sind", "Shrd", "Gujr", "Takr", "Khoj", "Mult", "Modi", "Beng", "Tirh", "Orya", "Dogr", "Soyo", "Tibt", "Phag", "Marc", "Newa", "Bhks", "Lepc", "Limb", "Mtei", "Ahom", "Zanb", "Telu", "Gran", "Saur", "Knda", "Taml", "Mlym", "Sinh", "Cakm", "Mymr", "Lana", "Thai", "Tale", "Talu", "Khmr", "Laoo", "Kali", "Cham", "Tavt", "Bali", "Java", "Sund", "Rjng", "Leke", "Batk", "Maka", "Bugi", "Tglg", "Hano", "Buhd", "Tagb", "Qaaa", "Sora", "Lisu", "Lina", "Linb", "Cprt", "Hira", "Kana", "Hrkt", "Jpan", "Nkgb", "Ethi", "Bamu", "Kpel", "Qabx", "Mend", "Afak", "Cans", "Cher", "Hmng", "Yiii", "Vaii", "Wole", "Zsye", "Zinh", "Zmth", "Zsym", "Zxxx", "Zyyy", "Zzzz", "Nshu", "Hani", "Hans", "Hant", "Hanb", "Kitl", "Jurc"] - } - }, - "name": { - "$id": "#/properties/name", - "type": "string", - "title": "Filename of the model", - "pattern": "^(.*)$" - }, - "graphemes": { - "$id": "#/properties/graphemes", - "type": "array", - "title": "Code points recognizable by the model", - "uniqueItems": true, - "minItems": 1, - "items": { - "$id": "#/properties/graphemes/items", - "type": "string", - "pattern": "^(.*)$" - } - } - } -} diff --git a/kraken/repo.py b/kraken/repo.py index f902e9f8..5168da3c 100644 --- a/kraken/repo.py +++ b/kraken/repo.py @@ -1,5 +1,5 @@ # -# Copyright 2015 Benjamin Kiessling +# Copyright 2025 Benjamin Kiessling # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,262 +13,72 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """ -Accessors to the model repository on zenodo. -""" -import json -import logging -import os -from contextlib import closing -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable - -import requests - -from kraken.lib.exceptions import KrakenRepoException - -if TYPE_CHECKING: - from os import PathLike - - -__all__ = ['get_model', 'get_description', 'get_listing', 'publish_model'] +kraken.repo +~~~~~~~~~~~ -logger = logging.getLogger(__name__) - -MODEL_REPO = 'https://zenodo.org/api/' -SUPPORTED_MODELS = set(['kraken_pytorch']) - - -def publish_model(model_file: [str, 'PathLike'] = None, - metadata: dict = None, - access_token: str = None, - callback: Callable[[int, int], Any] = lambda: None, - private: bool = False) -> str: - """ - Publishes a model to the repository. - - Args: - model_file: Path to read model from. - metadata: Metadata dictionary - access_token: Zenodo API access token - callback: Function called with octet-wise progress. - private: Whether to generate a community inclusion request that makes - the model recoverable by the public. - """ - model_file = Path(model_file) - fp = open(model_file, 'rb') - _metadata = json.dumps(metadata) - total = model_file.stat().st_size + len(_metadata) + 3 - headers = {"Content-Type": "application/json"} - r = requests.post(f'{MODEL_REPO}deposit/depositions', - params={'access_token': access_token}, json={}, - headers=headers) - r.raise_for_status() - callback(total, 1) - deposition_id = r.json()['id'] - data = {'filename': 'metadata.json'} - files = {'file': ('metadata.json', _metadata)} - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/files', - params={'access_token': access_token}, data=data, - files=files) - r.raise_for_status() - callback(total, len(_metadata)) - data = {'filename': metadata['name']} - files = {'file': fp} - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/files', - params={'access_token': access_token}, data=data, - files=files) - r.raise_for_status() - callback(total, model_file.stat().st_size) - # fill zenodo metadata - data = {'metadata': { - 'title': metadata['summary'], - 'upload_type': 'publication', - 'publication_type': 'other', - 'description': metadata['description'], - 'creators': metadata['authors'], - 'access_right': 'open', - 'keywords': ['kraken_pytorch'], - 'license': metadata['license'] - } - } +Wrappers around the htrmopo reference implementation implementing +kraken-specific filtering for repository querying operations. +""" +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Dict, Optional, TypeVar, Literal - if not private: - data['metadata']['communities'] = [{'identifier': 'ocr_models'}] - # add link to training data to metadata - if 'source' in metadata: - data['metadata']['related_identifiers'] = [{'relation': 'isSupplementTo', 'identifier': metadata['source']}] - r = requests.put(f'{MODEL_REPO}deposit/depositions/{deposition_id}', - params={'access_token': access_token}, - data=json.dumps(data), - headers=headers) - r.raise_for_status() - callback(total, 1) - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/actions/publish', - params={'access_token': access_token}) - r.raise_for_status() - callback(total, 1) - return r.json()['doi'] +from htrmopo import get_description as mopo_get_description +from htrmopo import get_listing as mopo_get_listing +from htrmopo.record import v0RepositoryRecord, v1RepositoryRecord -def get_model(model_id: str, path: str, callback: Callable[[int, int], Any] = lambda total, advance: None) -> str: - """ - Retrieves a model and saves it to a path. +_v0_or_v1_Record = TypeVar('_v0_or_v1_Record', v0RepositoryRecord, v1RepositoryRecord) - Args: - model_id (str): DOI of the model - path (str): Destination to write model to. - callback (func): Function called for every 1024 octet chunk received. - Returns: - The identifier the model can be called through on the command line. - Will usually be the file name of the model. +def get_description(model_id: str, + callback: Callable[..., Any] = lambda: None, + version: Optional[Literal['v0', 'v1']] = None, + filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> _v0_or_v1_Record: """ - logger.info(f'Saving model {model_id} to {path}') - r = requests.get(f'{MODEL_REPO}records', params={'q': f'doi:"{model_id}"', 'allversions': '1'}) - r.raise_for_status() - callback(0, 0) - resp = r.json() - if resp['hits']['total'] != 1: - logger.error(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - raise KrakenRepoException(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - - record = resp['hits']['hits'][0] - metadata_url = [x['links']['self'] for x in record['files'] if x['key'] == 'metadata.json'][0] - r = requests.get(metadata_url) - r.raise_for_status() - resp = r.json() - # callable model identifier - nat_id = resp['name'] - model_url = [x['links']['self'] for x in record['files'] if x['key'] == nat_id][0] - spath = os.path.join(path, nat_id) - logger.debug(f'downloading model file {model_url} to {spath}') - with closing(requests.get(model_url, stream=True)) as r: - file_size = int(r.headers['Content-length']) - with open(spath, 'wb') as f: - for chunk in r.iter_content(chunk_size=1024): - callback(file_size, len(chunk)) - f.write(chunk) - return nat_id - - -def get_description(model_id: str, callback: Callable[..., Any] = lambda: None) -> dict: - """ - Fetches the metadata for a single model from the zenodo repository. + Filters the output of htrmopo.get_description with a custom function. Args: - model_id (str): DOI of the model. - callback (callable): Optional function called once per HTTP request. - - Returns: - Dict + model_id: model DOI + callback: Progress callback + version: + filter_fn: Function called to filter the retrieved record. """ - logger.info(f'Retrieving metadata for {model_id}') - r = requests.get(f'{MODEL_REPO}records', params={'q': f'doi:"{model_id}"', 'allversions': '1'}) - r.raise_for_status() - callback() - resp = r.json() - if resp['hits']['total'] != 1: - logger.error(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - raise KrakenRepoException(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - record = resp['hits']['hits'][0] - metadata = record['metadata'] - if 'keywords' not in metadata: - logger.error('No keywords included on deposit') - raise KrakenRepoException('No keywords included on deposit.') - model_type = SUPPORTED_MODELS.intersection(metadata['keywords']) - if not model_type: - msg = 'Unsupported model type(s): {}'.format(', '.join(metadata['keywords'])) - logger.error(msg) - raise KrakenRepoException(msg) - meta_json = None - for file in record['files']: - for file in record['files']: - if file['key'] == 'metadata.json': - callback() - r = requests.get(file['links']['self']) - r.raise_for_status() - try: - meta_json = r.json() - except Exception: - msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format' - logger.error(msg) - raise KrakenRepoException(msg) - if not meta_json: - msg = 'Mo metadata.json found for \'{}\' ({})'.format(record['metadata']['title'], record['metadata']['doi']) - logger.error(msg) - raise KrakenRepoException(msg) - # merge metadata.json into DataCite - metadata.update({'graphemes': meta_json['graphemes'], - 'summary': meta_json['summary'], - 'script': meta_json['script'], - 'link': record['links']['latest'], - 'type': [x.split('_')[1] for x in model_type], - 'accuracy': meta_json['accuracy']}) - return metadata + desc = mopo_get_description(model_id, callback, version) + if not filter_fn(desc): + raise ValueError(f'Record {model_id} exists but is not a valid kraken record') + return desc -def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None) -> dict: +def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None, + from_date: Optional[str] = None, + filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> Dict[str, Dict[str, _v0_or_v1_Record]]: """ - Fetches a listing of all kraken models from the zenodo repository. + Returns a filtered representation of the model repository grouped by + concept DOI. Args: - callback (Callable): Function called after each HTTP request. + callback: Progress callback + from_data: + filter_fn: Function called for each record object Returns: - Dict of models with each model. + A dictionary mapping group DOIs to one record object per deposit. The + record of the highest available schema version is retained. """ - logger.info('Retrieving model list') - records = [] - r = requests.get('{}{}'.format(MODEL_REPO, 'records'), params={'communities': 'ocr_models'}) - r.raise_for_status() - callback(1, 1) - resp = r.json() - if not resp['hits']['total']: - logger.error('No models found in community \'ocr_models\'') - raise KrakenRepoException('No models found in repository \'ocr_models\'') - logger.debug('Total of {} records in repository'.format(resp['hits']['total'])) - total = resp['hits']['total'] - callback(total, 0) - records.extend(resp['hits']['hits']) - if 'links' in resp and 'next' in resp['links']: - while 'next' in resp['links']: - logger.debug('Fetching next page') - r = requests.get(resp['links']['next']) - r.raise_for_status() - resp = r.json() - logger.debug('Found {} new records'.format(len(resp['hits']['hits']))) - records.extend(resp['hits']['hits']) - logger.debug('Retrieving model metadata') - models = {} - # fetch metadata.jsn for each model - for record in records: - if 'keywords' not in record['metadata']: - continue - model_type = SUPPORTED_MODELS.intersection(record['metadata']['keywords']) - if not model_type: - continue - metadata = None - for file in record['files']: - if file['key'] == 'metadata.json': - callback(total, 1) - r = requests.get(file['links']['self']) - r.raise_for_status() - try: - metadata = r.json() - except Exception: - msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format' - logger.error(msg) - raise KrakenRepoException(msg) - if not metadata: - logger.warning(f"No metadata found for record '{record['doi']}'.") - continue - # merge metadata.jsn into DataCite - key = record['metadata']['doi'] - models[key] = record['metadata'] - models[key].update({'graphemes': metadata['graphemes'], - 'summary': metadata['summary'], - 'script': metadata['script'], - 'link': record['links']['latest'], - 'type': [x.split('_')[1] for x in model_type]}) - return models + repository = mopo_get_listing(callback, from_date) + # aggregate models under their concept DOI + concepts = defaultdict(list) + for item in repository.values(): + # filter records here + item = {k: v for k, v in item.items() if filter_fn(v)} + # both got the same DOI information + record = item.get('v1', item.get('v0', None)) + if record is not None: + concepts[record.concept_doi].append(record) + + for k, v in concepts.items(): + concepts[k] = sorted(v, key=lambda x: x.publication_date, reverse=True) + + return concepts diff --git a/setup.cfg b/setup.cfg index 8e0e05b2..0b29c0b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,9 +58,11 @@ install_requires = scikit-image~=0.24.0 shapely>=2.0.6,~=2.0.6 pyarrow + htrmopo>=0.3,~=0.3 lightning~=2.4.0 torchmetrics>=1.1.0 threadpoolctl~=3.5.0 + platformdirs rich [options.extras_require] diff --git a/tests/test_repo.py b/tests/test_repo.py index 3b54f099..f7d3ff29 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -11,7 +11,7 @@ class TestRepo(unittest.TestCase): """ - Testing interaction with the model repository. + Testing our wrappers around HTRMoPo """ def setUp(self): @@ -33,30 +33,11 @@ def test_get_description(self): Tests fetching the description of a model. """ record = repo.get_description('10.5281/zenodo.8425684') - self.assertEqual(record['doi'], '10.5281/zenodo.8425684') - - def test_get_model(self): - """ - Tests fetching a model. - """ - id = repo.get_model('10.5281/zenodo.8425684', - path=self.temp_model.name) - self.assertEqual(id, 'omnisyr_best.mlmodel') - self.assertEqual((self.temp_path / id).stat().st_size, 16245671) + self.assertEqual(record.doi, '10.5281/zenodo.8425684') def test_prev_record_version_get_description(self): """ Tests fetching the description of a model that has a superseding newer version. """ record = repo.get_description('10.5281/zenodo.6657809') - self.assertEqual(record['doi'], '10.5281/zenodo.6657809') - - def test_prev_record_version_get_model(self): - """ - Tests fetching a model that has a superseding newer version. - """ - id = repo.get_model('10.5281/zenodo.6657809', - path=self.temp_model.name) - self.assertEqual(id, 'HTR-United-Manu_McFrench.mlmodel') - self.assertEqual((self.temp_path / id).stat().st_size, 16176844) - + self.assertEqual(record.doi, '10.5281/zenodo.6657809')