From ffb1ef8dda81217456cba49f18ef6852365e5ff1 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Fri, 3 Jan 2025 23:16:06 +0100 Subject: [PATCH 1/9] wip for new model repository --- .github/workflows/test.yml | 3 - environment.yml | 1 + environment_cuda.yml | 1 + kraken/ketos/repo.py | 169 ++++++++++++++--------- kraken/kraken.py | 136 +++++++++++++++--- kraken/repo.py | 274 ------------------------------------- setup.cfg | 1 + 7 files changed, 226 insertions(+), 359 deletions(-) delete mode 100644 kraken/repo.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eb2fc9833..e609452dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,9 +52,6 @@ jobs: python -m build --sdist --wheel --outdir dist/ . - name: Publish a Python distribution to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} - name: Upload PyPI artifacts to GH storage uses: actions/upload-artifact@v3 with: diff --git a/environment.yml b/environment.yml index 6d192901b..c624a8a21 100644 --- a/environment.yml +++ b/environment.yml @@ -32,4 +32,5 @@ dependencies: - setuptools>=36.6.0,<70.0.0 - pip: - coremltools~=8.1 + - htrmopo - file:. diff --git a/environment_cuda.yml b/environment_cuda.yml index d9525927b..243a9b7c5 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -33,4 +33,5 @@ dependencies: - setuptools>=36.6.0,<70.0.0 - pip: - coremltools~=8.1 + - htrmopo - file:. diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index fe67bf803..ea2f4dd91 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -18,94 +18,139 @@ Command line driver for publishing models to the model repository. """ +import re import logging -import os import click +from pathlib import Path from .util import message logging.captureWarnings(True) logger = logging.getLogger('kraken') +def _get_field_list(name): + values = [] + while True: + value = click.prompt(name, default=None) + if value is not None: + values.append(value) + else: + break + return values + + @click.command('publish') @click.pass_context @click.option('-i', '--metadata', show_default=True, - type=click.File(mode='r', lazy=True), help='Metadata for the ' - 'model. Will be prompted from the user if not given') + type=click.File(mode='r', lazy=True), help='Model card file for the model.') @click.option('-a', '--access-token', prompt=True, help='Zenodo access token') +@click.option('-d', '--doi', prompt=True, help='DOI of an existing record to update') @click.option('-p', '--private/--public', default=False, help='Disables Zenodo ' 'community inclusion request. Allows upload of models that will not show ' 'up on `kraken list` output') @click.argument('model', nargs=1, type=click.Path(exists=False, readable=True, dir_okay=False)) -def publish(ctx, metadata, access_token, private, model): +def publish(ctx, metadata, access_token, doi, private, model): """ Publishes a model on the zenodo model repository. """ import json + import tempfile + + from htrmopo import publish_model, update_model - from importlib import resources - from jsonschema import validate - from jsonschema.exceptions import ValidationError + pub_fn = publish_model - from kraken import repo - from kraken.lib import models + from kraken.lib.vgsl import TorchVGSLModel from kraken.lib.progress import KrakenDownloadProgressBar - ref = resources.files('kraken').joinpath('metadata.schema.json') - with open(ref, 'rb') as fp: - schema = json.load(fp) - - nn = models.load_any(model) - - if not metadata: - author = click.prompt('author') - affiliation = click.prompt('affiliation') - summary = click.prompt('summary') - description = click.edit('Write long form description (training data, transcription standards) of the model here') - accuracy_default = None - # take last accuracy measurement in model metadata - if 'accuracy' in nn.nn.user_metadata and nn.nn.user_metadata['accuracy']: - accuracy_default = nn.nn.user_metadata['accuracy'][-1][1] * 100 - accuracy = click.prompt('accuracy on test set', type=float, default=accuracy_default) - script = [ - click.prompt( - 'script', - type=click.Choice( - sorted( - schema['properties']['script']['items']['enum'])), - show_choices=True)] - license = click.prompt( - 'license', - type=click.Choice( - sorted( - schema['properties']['license']['enum'])), - show_choices=True) - metadata = { - 'authors': [{'name': author, 'affiliation': affiliation}], - 'summary': summary, - 'description': description, - 'accuracy': accuracy, - 'license': license, - 'script': script, - 'name': os.path.basename(model), - 'graphemes': ['a'] - } - while True: - try: - validate(metadata, schema) - except ValidationError as e: - message(e.message) - metadata[e.path[-1]] = click.prompt(e.path[-1], type=float if e.schema['type'] == 'number' else str) - continue - break + _yaml_delim = r'(?:---|\+\+\+)' + _yaml = r'(.*?)' + _content = r'\s*(.+)$' + _re_pattern = r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content + _yaml_regex = re.compile(_re_pattern, re.S | re.M) + nn = TorchVGSLModel.load_model(model) + + frontmatter = {} + # construct metadata if none is given + if metadata: + frontmatter, content = _yaml_regex.match(metadata.read()).groups() else: - metadata = json.load(metadata) - validate(metadata, schema) - metadata['graphemes'] = [char for char in ''.join(nn.codec.c2l.keys())] - with KrakenDownloadProgressBar() as progress: + frontmatter['summary'] = click.prompt('summary') + content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here') + + creators = [] + while True: + author = click.prompt('author', default=None) + affiliation = click.prompt('affiliation', default=None) + orcid = click.prompt('orcid', default=None) + if author is not None: + creators.append({'author': author}) + else: + break + if affiliation is not None: + creators[-1]['affiliation'] = affiliation + if orcid is not None: + creators[-1]['orcid'] = orcid + frontmatter['authors'] = creators + frontmatter['license'] = click.prompt('license') + frontmatter['language'] = _get_field_list('language') + frontmatter['script'] = _get_field_list('script') + + if len(tags := _get_field_list('tag')): + frontmatter['tags'] = tags + ['kraken_pytorch'] + if len(datasets := _get_field_list('dataset URL')): + frontmatter['datasets'] = datasets + if len(base_model := _get_field_list('base model URL')): + frontmatter['base_model'] = base_model + + # take last metrics field, falling back to accuracy field in model metadata + metrics = {} + if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']: + metrics['cer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_accuracy'] + metrics['wer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_word_accuracy'] + elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']: + metrics['cer'] = 100 - nn.user_metadata['accuracy'] + frontmatter['metrics'] = metrics + software_hints = ['kind=vgsl'] + + # some recognition-specific software hints + if nn.model_type == 'recognition': + software_hints.append([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', 'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) + frontmatter['software_hints'] = software_hints + + frontmatter['software_name'] = 'kraken' + + # build temporary directory + with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress: upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False) - oid = repo.publish_model(model, metadata, access_token, lambda total, advance: progress.update(upload_task, total=total, advance=advance), private) - message('model PID: {}'.format(oid)) + + model = Path(model) + tmpdir = Path(tmpdir) + (tmpdir / model.name).symlink_to(model) + # v0 metadata only supports recognition models + if nn.model_type == 'recognition': + v0_metadata = { + 'summary': frontmatter['summary'], + 'description': content, + 'license': frontmatter['license'], + 'script': frontmatter['script'], + 'name': model.name, + 'graphemes': [char for char in ''.join(nn.codec.c2l.keys())] + } + if frontmatter['metrics']: + v0_metadata['accuracy'] = 100 - metrics['cer'] + with open(tmpdir / 'metadata.json', 'w') as fo: + json.dump(v0_metadata, fo) + kwargs = {'model': tmpdir, + 'model_card': f'---\n{frontmatter}---\n{content}', + 'access_token': access_token, + 'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance), + 'private': private} + if doi: + pub_fn = update_model + kwargs['model_id'] = doi + oid = pub_fn(**kwargs) + message(f'model PID: {oid}') diff --git a/kraken/kraken.py b/kraken/kraken.py index 23a12daf9..a41c41f50 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -31,7 +31,15 @@ import click from PIL import Image from importlib import resources + +from rich import print +from rich.tree import Tree +from rich.table import Table +from rich.console import Group from rich.traceback import install +from rich.logging import RichHandler +from rich.markdown import Markdown +from rich.progress import Progress from kraken.lib import log @@ -677,29 +685,90 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): @cli.command('show') @click.pass_context +@click.option('-V', '--metadata-version', + default='highest', + type=click.Choice(['v0', 'v1', 'highest']), + help='Version of metadata to fetch if multiple exist in repository.') @click.argument('model_id') -def show(ctx, model_id): +def show(ctx, metadata_version, model_id): """ Retrieves model metadata from the repository. """ - from kraken import repo + from htrmopo import get_description + from htrmopo.util import iso15924_to_name, iso639_3_to_name from kraken.lib.util import is_printable, make_printable - desc = repo.get_description(model_id) + def _render_creators(creators): + o = [] + for creator in creators: + c_text = creator['name'] + if (orcid := creator.get('orcid', None)) is not None: + c_text += f' ({orcid})' + if (affiliation := creator.get('affiliation', None)) is not None: + c_text += f' ({affiliation})' + o.append(c_text) + return o - chars = [] - combining = [] - for char in sorted(desc['graphemes']): - if not is_printable(char): - combining.append(make_printable(char)) - else: - chars.append(char) - message( - 'name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\naccuracy: {:.2f}%\nlicense: {}\nauthor(s): {}\ndate: {}'.format( - model_id, desc['summary'], desc['description'], ' '.join( - desc['script']), ''.join(chars), ', '.join(combining), desc['accuracy'], desc['license']['id'], '; '.join( - x['name'] for x in desc['creators']), desc['publication_date'])) - ctx.exit(0) + def _render_metrics(metrics): + return [f'{k}: {v:.2f}' for k, v in metrics.items()] + + if metadata_version == 'highest': + metadata_version = None + + try: + desc = get_description(model_id, version=metadata_version) + except ValueError as e: + logger.error(e) + ctx.exit(1) + + if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: + logger.error('Record exists but is not a kraken-compatible model') + ctx.exit(1) + + if desc.version == 'v0': + chars = [] + combining = [] + for char in sorted(desc.graphemes): + if not is_printable(char): + combining.append(make_printable(char)) + else: + chars.append(char) + + table = Table(title=desc.summary, show_header=False) + table.add_column('key', justify="left", no_wrap=True) + table.add_column('value', justify="left", no_wrap=False) + table.add_row('DOI', desc.doi) + table.add_row('concept DOI', desc.concept_doi) + table.add_row('publication date', desc.publication_date.isoformat()) + table.add_row('model type', Group(*desc.model_type)) + table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script])) + table.add_row('alphabet', Group(' '.join(chars), ', '.join(combining))) + table.add_row('keywords', Group(*desc.keywords)) + table.add_row('metrics', Group(*_render_metrics(desc.metrics))) + table.add_row('license', desc.license) + table.add_row('creators', Group(*_render_creators(desc.creators))) + table.add_row('description', desc.description) + elif desc.version == 'v1': + table = Table(title=desc.summary, show_header=False) + table.add_column('key', justify="left", no_wrap=True) + table.add_column('value', justify="left", no_wrap=False) + table.add_row('DOI', desc.doi) + table.add_row('concept DOI', desc.concept_doi) + table.add_row('publication date', desc.publication_date.isoformat()) + table.add_row('model type', Group(*desc.model_type)) + table.add_row('language', Group(*[iso639_3_to_name(x) for x in desc.language])) + table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script])) + table.add_row('keywords', Group(*desc.keywords)) + table.add_row('datasets', Group(*desc.datasets)) + table.add_row('metrics', Group(*_render_metrics(desc.metrics))) + table.add_row('base model', Group(*desc.base_model)) + table.add_row('software', desc.software_name) + table.add_row('software_hints', Group(*desc.software_hints)) + table.add_row('license', desc.license) + table.add_row('creators', Group(*_render_creators(desc.creators))) + table.add_row('description', Markdown(desc.description)) + + print(table) @cli.command('list') @@ -708,14 +777,41 @@ def list_models(ctx): """ Lists models in the repository. """ - from kraken import repo + from htrmopo import get_listing + from collections import defaultdict from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) - model_list = repo.get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance)) - for id, metadata in model_list.items(): - message('{} ({}) - {}'.format(id, ', '.join(metadata['type']), metadata['summary'])) + repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance)) + # aggregate models under their concept DOI + concepts = defaultdict(list) + for item in repository.values(): + # both got the same DOI information + record = item['v0'] if item['v0'] else item['v1'] + concepts[record.concept_doi].append(record.doi) + + table = Table(show_header=True) + table.add_column('DOI', justify="left", no_wrap=True) + table.add_column('summary', justify="left", no_wrap=False) + table.add_column('model type', justify="left", no_wrap=False) + table.add_column('keywords', justify="left", no_wrap=False) + + for k, v in concepts.items(): + records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v] + records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records) + records = sorted(records, key=lambda x: x.publication_date, reverse=True) + if not len(records): + continue + + t = Tree(k) + [t.add(x.doi) for x in records] + table.add_row(t, + Group(*[''] + [x.summary for x in records]), + Group(*[''] + ['; '.join(x.model_type) for x in records]), + Group(*[''] + ['; '.join(x.keywords) for x in records])) + + print(table) ctx.exit(0) diff --git a/kraken/repo.py b/kraken/repo.py deleted file mode 100644 index f902e9f80..000000000 --- a/kraken/repo.py +++ /dev/null @@ -1,274 +0,0 @@ -# -# Copyright 2015 Benjamin Kiessling -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -""" -Accessors to the model repository on zenodo. -""" -import json -import logging -import os -from contextlib import closing -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable - -import requests - -from kraken.lib.exceptions import KrakenRepoException - -if TYPE_CHECKING: - from os import PathLike - - -__all__ = ['get_model', 'get_description', 'get_listing', 'publish_model'] - -logger = logging.getLogger(__name__) - -MODEL_REPO = 'https://zenodo.org/api/' -SUPPORTED_MODELS = set(['kraken_pytorch']) - - -def publish_model(model_file: [str, 'PathLike'] = None, - metadata: dict = None, - access_token: str = None, - callback: Callable[[int, int], Any] = lambda: None, - private: bool = False) -> str: - """ - Publishes a model to the repository. - - Args: - model_file: Path to read model from. - metadata: Metadata dictionary - access_token: Zenodo API access token - callback: Function called with octet-wise progress. - private: Whether to generate a community inclusion request that makes - the model recoverable by the public. - """ - model_file = Path(model_file) - fp = open(model_file, 'rb') - _metadata = json.dumps(metadata) - total = model_file.stat().st_size + len(_metadata) + 3 - headers = {"Content-Type": "application/json"} - r = requests.post(f'{MODEL_REPO}deposit/depositions', - params={'access_token': access_token}, json={}, - headers=headers) - r.raise_for_status() - callback(total, 1) - deposition_id = r.json()['id'] - data = {'filename': 'metadata.json'} - files = {'file': ('metadata.json', _metadata)} - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/files', - params={'access_token': access_token}, data=data, - files=files) - r.raise_for_status() - callback(total, len(_metadata)) - data = {'filename': metadata['name']} - files = {'file': fp} - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/files', - params={'access_token': access_token}, data=data, - files=files) - r.raise_for_status() - callback(total, model_file.stat().st_size) - # fill zenodo metadata - data = {'metadata': { - 'title': metadata['summary'], - 'upload_type': 'publication', - 'publication_type': 'other', - 'description': metadata['description'], - 'creators': metadata['authors'], - 'access_right': 'open', - 'keywords': ['kraken_pytorch'], - 'license': metadata['license'] - } - } - - if not private: - data['metadata']['communities'] = [{'identifier': 'ocr_models'}] - - # add link to training data to metadata - if 'source' in metadata: - data['metadata']['related_identifiers'] = [{'relation': 'isSupplementTo', 'identifier': metadata['source']}] - r = requests.put(f'{MODEL_REPO}deposit/depositions/{deposition_id}', - params={'access_token': access_token}, - data=json.dumps(data), - headers=headers) - r.raise_for_status() - callback(total, 1) - r = requests.post(f'{MODEL_REPO}deposit/depositions/{deposition_id}/actions/publish', - params={'access_token': access_token}) - r.raise_for_status() - callback(total, 1) - return r.json()['doi'] - - -def get_model(model_id: str, path: str, callback: Callable[[int, int], Any] = lambda total, advance: None) -> str: - """ - Retrieves a model and saves it to a path. - - Args: - model_id (str): DOI of the model - path (str): Destination to write model to. - callback (func): Function called for every 1024 octet chunk received. - - Returns: - The identifier the model can be called through on the command line. - Will usually be the file name of the model. - """ - logger.info(f'Saving model {model_id} to {path}') - r = requests.get(f'{MODEL_REPO}records', params={'q': f'doi:"{model_id}"', 'allversions': '1'}) - r.raise_for_status() - callback(0, 0) - resp = r.json() - if resp['hits']['total'] != 1: - logger.error(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - raise KrakenRepoException(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - - record = resp['hits']['hits'][0] - metadata_url = [x['links']['self'] for x in record['files'] if x['key'] == 'metadata.json'][0] - r = requests.get(metadata_url) - r.raise_for_status() - resp = r.json() - # callable model identifier - nat_id = resp['name'] - model_url = [x['links']['self'] for x in record['files'] if x['key'] == nat_id][0] - spath = os.path.join(path, nat_id) - logger.debug(f'downloading model file {model_url} to {spath}') - with closing(requests.get(model_url, stream=True)) as r: - file_size = int(r.headers['Content-length']) - with open(spath, 'wb') as f: - for chunk in r.iter_content(chunk_size=1024): - callback(file_size, len(chunk)) - f.write(chunk) - return nat_id - - -def get_description(model_id: str, callback: Callable[..., Any] = lambda: None) -> dict: - """ - Fetches the metadata for a single model from the zenodo repository. - - Args: - model_id (str): DOI of the model. - callback (callable): Optional function called once per HTTP request. - - Returns: - Dict - """ - logger.info(f'Retrieving metadata for {model_id}') - r = requests.get(f'{MODEL_REPO}records', params={'q': f'doi:"{model_id}"', 'allversions': '1'}) - r.raise_for_status() - callback() - resp = r.json() - if resp['hits']['total'] != 1: - logger.error(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - raise KrakenRepoException(f'Found {resp["hits"]["total"]} models when querying for id \'{model_id}\'') - record = resp['hits']['hits'][0] - metadata = record['metadata'] - if 'keywords' not in metadata: - logger.error('No keywords included on deposit') - raise KrakenRepoException('No keywords included on deposit.') - model_type = SUPPORTED_MODELS.intersection(metadata['keywords']) - if not model_type: - msg = 'Unsupported model type(s): {}'.format(', '.join(metadata['keywords'])) - logger.error(msg) - raise KrakenRepoException(msg) - meta_json = None - for file in record['files']: - for file in record['files']: - if file['key'] == 'metadata.json': - callback() - r = requests.get(file['links']['self']) - r.raise_for_status() - try: - meta_json = r.json() - except Exception: - msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format' - logger.error(msg) - raise KrakenRepoException(msg) - if not meta_json: - msg = 'Mo metadata.json found for \'{}\' ({})'.format(record['metadata']['title'], record['metadata']['doi']) - logger.error(msg) - raise KrakenRepoException(msg) - # merge metadata.json into DataCite - metadata.update({'graphemes': meta_json['graphemes'], - 'summary': meta_json['summary'], - 'script': meta_json['script'], - 'link': record['links']['latest'], - 'type': [x.split('_')[1] for x in model_type], - 'accuracy': meta_json['accuracy']}) - return metadata - - -def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None) -> dict: - """ - Fetches a listing of all kraken models from the zenodo repository. - - Args: - callback (Callable): Function called after each HTTP request. - - Returns: - Dict of models with each model. - """ - logger.info('Retrieving model list') - records = [] - r = requests.get('{}{}'.format(MODEL_REPO, 'records'), params={'communities': 'ocr_models'}) - r.raise_for_status() - callback(1, 1) - resp = r.json() - if not resp['hits']['total']: - logger.error('No models found in community \'ocr_models\'') - raise KrakenRepoException('No models found in repository \'ocr_models\'') - logger.debug('Total of {} records in repository'.format(resp['hits']['total'])) - total = resp['hits']['total'] - callback(total, 0) - records.extend(resp['hits']['hits']) - if 'links' in resp and 'next' in resp['links']: - while 'next' in resp['links']: - logger.debug('Fetching next page') - r = requests.get(resp['links']['next']) - r.raise_for_status() - resp = r.json() - logger.debug('Found {} new records'.format(len(resp['hits']['hits']))) - records.extend(resp['hits']['hits']) - logger.debug('Retrieving model metadata') - models = {} - # fetch metadata.jsn for each model - for record in records: - if 'keywords' not in record['metadata']: - continue - model_type = SUPPORTED_MODELS.intersection(record['metadata']['keywords']) - if not model_type: - continue - metadata = None - for file in record['files']: - if file['key'] == 'metadata.json': - callback(total, 1) - r = requests.get(file['links']['self']) - r.raise_for_status() - try: - metadata = r.json() - except Exception: - msg = f'Metadata for \'{record["metadata"]["title"]}\' ({record["metadata"]["doi"]}) not in JSON format' - logger.error(msg) - raise KrakenRepoException(msg) - if not metadata: - logger.warning(f"No metadata found for record '{record['doi']}'.") - continue - # merge metadata.jsn into DataCite - key = record['metadata']['doi'] - models[key] = record['metadata'] - models[key].update({'graphemes': metadata['graphemes'], - 'summary': metadata['summary'], - 'script': metadata['script'], - 'link': record['links']['latest'], - 'type': [x.split('_')[1] for x in model_type]}) - return models diff --git a/setup.cfg b/setup.cfg index 8e0e05b2e..efc7000ba 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,6 +58,7 @@ install_requires = scikit-image~=0.24.0 shapely>=2.0.6,~=2.0.6 pyarrow + htrmopo lightning~=2.4.0 torchmetrics>=1.1.0 threadpoolctl~=3.5.0 From 716c520159df9ad9950ce37e4decc17df36316dc Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sat, 4 Jan 2025 21:03:09 +0100 Subject: [PATCH 2/9] Factor out htrmopo calls to include filters --- kraken/kraken.py | 56 ++++++++++++++----------------- kraken/repo.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 31 deletions(-) create mode 100644 kraken/repo.py diff --git a/kraken/kraken.py b/kraken/kraken.py index a41c41f50..ed01e9a86 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -694,8 +694,8 @@ def show(ctx, metadata_version, model_id): """ Retrieves model metadata from the repository. """ - from htrmopo import get_description from htrmopo.util import iso15924_to_name, iso639_3_to_name + from kraken.repo import get_description from kraken.lib.util import is_printable, make_printable def _render_creators(creators): @@ -716,15 +716,13 @@ def _render_metrics(metrics): metadata_version = None try: - desc = get_description(model_id, version=metadata_version) + desc = get_description(model_id, + version=metadata_version, + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) except ValueError as e: logger.error(e) ctx.exit(1) - if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: - logger.error('Record exists but is not a kraken-compatible model') - ctx.exit(1) - if desc.version == 'v0': chars = [] combining = [] @@ -777,19 +775,13 @@ def list_models(ctx): """ Lists models in the repository. """ - from htrmopo import get_listing - from collections import defaultdict + from kraken.repo import get_listing from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) - repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance)) - # aggregate models under their concept DOI - concepts = defaultdict(list) - for item in repository.values(): - # both got the same DOI information - record = item['v0'] if item['v0'] else item['v1'] - concepts[record.concept_doi].append(record.doi) + repository = get_listing(callback=lambda total, advance: progress.update(download_task, total=total, advance=advance), + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) table = Table(show_header=True) table.add_column('DOI', justify="left", no_wrap=True) @@ -797,13 +789,7 @@ def list_models(ctx): table.add_column('model type', justify="left", no_wrap=False) table.add_column('keywords', justify="left", no_wrap=False) - for k, v in concepts.items(): - records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v] - records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records) - records = sorted(records, key=lambda x: x.publication_date, reverse=True) - if not len(records): - continue - + for k, records in repository.items(): t = Tree(k) [t.add(x.doi) for x in records] table.add_row(t, @@ -812,7 +798,6 @@ def list_models(ctx): Group(*[''] + ['; '.join(x.keywords) for x in records])) print(table) - ctx.exit(0) @cli.command('get') @@ -822,20 +807,29 @@ def get(ctx, model_id): """ Retrieves a model from the repository. """ - from kraken import repo + import glob + + from htrmopo import get_model, get_description + from kraken.lib.progress import KrakenDownloadProgressBar try: - os.makedirs(click.get_app_dir(APP_NAME)) - except OSError: - pass + desc = get_description(model_id) + except ValueError as e: + logger.error(e) + ctx.exit(1) + + print(desc) + if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: + logger.error('Record exists but is not a kraken-compatible model') + ctx.exit(1) with KrakenDownloadProgressBar() as progress: download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False) - filename = repo.get_model(model_id, click.get_app_dir(APP_NAME), - lambda total, advance: progress.update(download_task, total=total, advance=advance)) - message(f'Model name: {filename}') - ctx.exit(0) + model_dir = get_model(model_id, + lambda total, advance: progress.update(download_task, total=total, advance=advance)) + model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iter_dir())) + message(f'Model dir: {model_dir} (model files: {model_candidates})') if __name__ == '__main__': diff --git a/kraken/repo.py b/kraken/repo.py new file mode 100644 index 000000000..f283deb72 --- /dev/null +++ b/kraken/repo.py @@ -0,0 +1,87 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" +kraken.repo +~~~~~~~~~~~ + +Wrappers around the htrmopo reference implementation implementing +kraken-specific filtering. +""" +import logging +import warnings +from pathlib import Path +from collections import defaultdict +from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal + +from collections.abc import Callable + +from htrmopo import get_description as mopo_get_description +from htrmopo import get_listing as mopo_get_listing +from htrmopo.record import v0RepositoryRecord, v1RepositoryRecord + + +_v0_or_v1_Record = TypeVar('_v0_or_v1_Record', v0RepositoryRecord, v1RepositoryRecord) + + +def get_description(model_id: str, + callback: Callable[..., Any] = lambda: None, + version: Optional[Literal['v0', 'v1']] = None, + filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> _v0_or_v1_Record: + """ + Filters the output of htrmopo.get_description with a custom function. + + Args: + model_id: model DOI + callback: Progress callback + version: + filter_fn: Function called to filter the retrieved record. + """ + desc = mopo_get_description(model_id, callback, version) + if not filter_fn(desc): + raise ValueError(f'Record {model_id} exists but is not a valid kraken record') + return desc + + +def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None, + from_date: Optional[str] = None, + filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> Dict[str, Dict[str, _v0_or_v1_Record]]: + """ + Returns a filtered representation of the model repository grouped by + concept DOI. + + Args: + callback: Progress callback + from_data: + filter_fn: Function called for each record object + + Returns: + A dictionary mapping group DOIs to one record object per deposit. The + record of the highest available schema version is retained. + """ + repository = mopo_get_listing(callback, from_date) + # aggregate models under their concept DOI + concepts = defaultdict(list) + for item in repository.values(): + # filter records here + item = {k: v for k, v in item.items() if filter_fn(v)} + # both got the same DOI information + record = item.get('v1', item.get('v0', None)) + if record is not None: + concepts[record.concept_doi].append(record) + + for k, v in concepts.items(): + concepts[k] = sorted(v, key=lambda x: x.publication_date, reverse=True) + + return concepts From 53fd619718029962b06b788e03ea731c5b77035d Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 5 Jan 2025 02:35:56 +0100 Subject: [PATCH 3/9] more input validation --- kraken/ketos/repo.py | 110 ++++++++++++++++++++++++++++++++----------- kraken/repo.py | 11 ++--- 2 files changed, 87 insertions(+), 34 deletions(-) diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index ea2f4dd91..00e48eed8 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -24,20 +24,55 @@ import click from pathlib import Path +from difflib import get_close_matches + from .util import message logging.captureWarnings(True) logger = logging.getLogger('kraken') -def _get_field_list(name): +def _validate_script(script: str) -> str: + from htrmopo.util import _iso15924 + if script not in _iso15924: + return get_close_matches(script, _iso15924.keys()) + return script + + +def _validate_language(language: str) -> str: + from htrmopo.util import _iso639_3 + if language not in _iso639_3: + return get_close_matches(language, _iso639_3.keys()) + return language + + +def _validate_license(license: str) -> str: + from htrmopo.util import _licenses + if license not in _licenses: + return get_close_matches(license, _licenses.keys()) + return license + + +def _get_field_list(name, + validation_fn=lambda x: x, + required: bool = False): values = [] while True: - value = click.prompt(name, default=None) - if value is not None: - values.append(value) + value = click.prompt(name, default='') + if value: + if (cand := validation_fn(value)) == value: + values.append(value) + else: + message(f'Not a valid {name} value. Did you mean {cand}?') else: - break + if click.confirm(f'All `{name}` values added?'): + if required and not values: + message(f'`{name}` is a required field.') + continue + else: + break + else: + continue return values @@ -46,7 +81,7 @@ def _get_field_list(name): @click.option('-i', '--metadata', show_default=True, type=click.File(mode='r', lazy=True), help='Model card file for the model.') @click.option('-a', '--access-token', prompt=True, help='Zenodo access token') -@click.option('-d', '--doi', prompt=True, help='DOI of an existing record to update') +@click.option('-d', '--doi', help='DOI of an existing record to update') @click.option('-p', '--private/--public', default=False, help='Disables Zenodo ' 'community inclusion request. Allows upload of models that will not show ' 'up on `kraken list` output') @@ -56,15 +91,16 @@ def publish(ctx, metadata, access_token, doi, private, model): Publishes a model on the zenodo model repository. """ import json + import yaml import tempfile from htrmopo import publish_model, update_model - pub_fn = publish_model - from kraken.lib.vgsl import TorchVGSLModel from kraken.lib.progress import KrakenDownloadProgressBar + pub_fn = publish_model + _yaml_delim = r'(?:---|\+\+\+)' _yaml = r'(.*?)' _content = r'\s*(.+)$' @@ -77,27 +113,44 @@ def publish(ctx, metadata, access_token, doi, private, model): # construct metadata if none is given if metadata: frontmatter, content = _yaml_regex.match(metadata.read()).groups() + frontmatter = yaml.safe_load(frontmatter) else: frontmatter['summary'] = click.prompt('summary') content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here') creators = [] + message('To stop adding authors, leave the author name field empty.') while True: - author = click.prompt('author', default=None) - affiliation = click.prompt('affiliation', default=None) - orcid = click.prompt('orcid', default=None) - if author is not None: - creators.append({'author': author}) + author = click.prompt('author name', default='') + if author: + creators.append({'name': author}) else: - break + if click.confirm('All authors added?'): + break + else: + continue + affiliation = click.prompt('affiliation', default='') + orcid = click.prompt('orcid', default='') if affiliation is not None: creators[-1]['affiliation'] = affiliation if orcid is not None: creators[-1]['orcid'] = orcid + if not creators: + raise click.UsageError('The `authors` field is obligatory. Aborting') + frontmatter['authors'] = creators - frontmatter['license'] = click.prompt('license') - frontmatter['language'] = _get_field_list('language') - frontmatter['script'] = _get_field_list('script') + while True: + license = click.prompt('license') + if (lic := _validate_license(license)) == license: + frontmatter['license'] = license + break + else: + message(f'Not a valid license identifer. Did you mean {lic}?') + + message('To stop adding values to the following fields, enter an empty field.') + + frontmatter['language'] = _get_field_list('language', _validate_language, required=True) + frontmatter['script'] = _get_field_list('script', _validate_script, required=True) if len(tags := _get_field_list('tag')): frontmatter['tags'] = tags + ['kraken_pytorch'] @@ -108,30 +161,33 @@ def publish(ctx, metadata, access_token, doi, private, model): # take last metrics field, falling back to accuracy field in model metadata metrics = {} - if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']: - metrics['cer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_accuracy'] - metrics['wer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_word_accuracy'] - elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']: - metrics['cer'] = 100 - nn.user_metadata['accuracy'] + if nn.user_metadata.get('metrics', None) is not None: + if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None: + metrics['cer'] = 100 - (val_accuracy * 100) + if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None: + metrics['wer'] = 100 - (val_word_accuracy * 100) + elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None: + metrics['cer'] = 100 - accuracy frontmatter['metrics'] = metrics software_hints = ['kind=vgsl'] # some recognition-specific software hints if nn.model_type == 'recognition': - software_hints.append([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', 'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) + software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) frontmatter['software_hints'] = software_hints frontmatter['software_name'] = 'kraken' + frontmatter['model_type'] = [nn.model_type] # build temporary directory with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress: upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False) - model = Path(model) + model = Path(model).resolve() tmpdir = Path(tmpdir) - (tmpdir / model.name).symlink_to(model) - # v0 metadata only supports recognition models + (tmpdir / model.name).resolve().symlink_to(model) if nn.model_type == 'recognition': + # v0 metadata only supports recognition models v0_metadata = { 'summary': frontmatter['summary'], 'description': content, @@ -145,7 +201,7 @@ def publish(ctx, metadata, access_token, doi, private, model): with open(tmpdir / 'metadata.json', 'w') as fo: json.dump(v0_metadata, fo) kwargs = {'model': tmpdir, - 'model_card': f'---\n{frontmatter}---\n{content}', + 'model_card': f'---\n{yaml.dump(frontmatter)}---\n{content}', 'access_token': access_token, 'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance), 'private': private} diff --git a/kraken/repo.py b/kraken/repo.py index f283deb72..5168da3cd 100644 --- a/kraken/repo.py +++ b/kraken/repo.py @@ -1,5 +1,5 @@ # -# Copyright 2015 Benjamin Kiessling +# Copyright 2025 Benjamin Kiessling # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,12 @@ ~~~~~~~~~~~ Wrappers around the htrmopo reference implementation implementing -kraken-specific filtering. +kraken-specific filtering for repository querying operations. """ -import logging -import warnings -from pathlib import Path from collections import defaultdict -from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal - from collections.abc import Callable +from typing import Any, Dict, Optional, TypeVar, Literal + from htrmopo import get_description as mopo_get_description from htrmopo import get_listing as mopo_get_listing From 77e1d44c86d52b4a8e496ef756fdfd70bc861894 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 5 Jan 2025 02:51:21 +0100 Subject: [PATCH 4/9] remove obsolete metadata schema --- kraken/metadata.schema.json | 103 ------------------------------------ 1 file changed, 103 deletions(-) delete mode 100644 kraken/metadata.schema.json diff --git a/kraken/metadata.schema.json b/kraken/metadata.schema.json deleted file mode 100644 index a7483f69a..000000000 --- a/kraken/metadata.schema.json +++ /dev/null @@ -1,103 +0,0 @@ - { - "definitions": {}, - "$schema": "http://json-schema.org/draft-07/schema#", - "$id": "http://example.com/root.json", - "type": "object", - "title": "The Root Schema", - "required": [ - "authors", - "summary", - "description", - "accuracy", - "license", - "script", - "name", - "graphemes" - ], - "properties": { - "authors": { - "$id": "#/properties/authors", - "type": "array", - "title": "Authors of the model", - "items": { - "$id": "#/properties/authors/items", - "type": "object", - "title": "items", - "required": [ - "name", - "affiliation" - ], - "properties": { - "name": { - "$id": "#/properties/authors/items/properties/name", - "type": "string", - "title": "A single author's name", - "pattern": "^(.*)$" - }, - "affiliation": { - "$id": "#/properties/authors/items/properties/affiliation", - "type": "string", - "title": "A single author's institutional affiliation", - "pattern": "^(.*)$" - } - } - } - }, - "summary": { - "$id": "#/properties/summary", - "type": "string", - "title": "A one-line summary of the model", - "pattern": "^(.*)$" - }, - "description": { - "$id": "#/properties/description", - "type": "string", - "title": "A long-form description of the model." - }, - "accuracy": { - "$id": "#/properties/accuracy", - "type": "number", - "title": "Test accuracy of the model", - "default": 0.0, - "minimum": 0.0, - "maximum": 100.0 - }, - "license": { - "$id": "#/properties/license", - "type": "string", - "title": "License of the model", - "default": "Apache-2.0", - "enum": ["AAL", "AFL-3.0", "AGPL-3.0", "APL-1.0", "APSL-2.0", "Against-DRM", "Apache-1.1", "Apache-2.0", "Artistic-2.0", "BSD-2-Clause", "BSD-3-Clause", "BSL-1.0", "BitTorrent-1.1", "CATOSL-1.1", "CC-BY-4.0", "CC-BY-NC-4.0", "CC-BY-SA-4.0", "CC0-1.0", "CDDL-1.0", "CECILL-2.1", "CNRI-Python", "CPAL-1.0", "CUA-OPL-1.0", "DSL", "ECL-2.0", "EFL-2.0", "EPL-1.0", "EPL-2.0", "EUDatagrid", "EUPL-1.1", "Entessa", "FAL-1.3", "Fair", "Frameworx-1.0", "GFDL-1.3-no-cover-texts-no-invariant-sections", "GPL-2.0", "GPL-3.0", "HPND", "IPA", "IPL-1.0", "ISC", "Intel", "LGPL-2.1", "LGPL-3.0", "LO-FR-2.0", "LPL-1.0", "LPL-1.02", "LPPL-1.3c", "MIT", "MPL-1.0", "MPL-1.1", "MPL-2.0", "MS-PL", "MS-RL", "MirOS", "Motosoto", "Multics", "NASA-1.3", "NCSA", "NGPL", "NPOSL-3.0", "NTP", "Naumen", "Nokia", "OCLC-2.0", "ODC-BY-1.0", "ODbL-1.0", "OFL-1.1", "OGL-Canada-2.0", "OGL-UK-1.0", "OGL-UK-2.0", "OGL-UK-3.0", "OGTSL", "OSL-3.0", "PDDL-1.0", "PHP-3.0", "PostgreSQL", "Python-2.0", "QPL-1.0", "RPL-1.5", "RPSL-1.0", "RSCPL", "SISSL", "SPL-1.0", "SimPL-2.0", "Sleepycat", "Talis", "Unlicense", "VSL-1.0", "W3C", "WXwindows", "Watcom-1.0", "Xnet", "ZPL-2.0", "Zlib", "dli-model-use", "geogratis", "hesa-withrights", "localauth-withrights", "met-office-cp", "mitre", "notspecified", "other-at", "other-closed", "other-nc", "other-open", "other-pd", "ukclickusepsi", "ukcrown", "ukcrown-withrights", "ukpsi"] - }, - "script": { - "$id": "#/properties/script", - "type": "array", - "uniqueItems": true, - "minItems": 1, - "title": "ISO 15924 scripts recognized by the model", - "items": { - "$id": "#/properties/script/items", - "type": "string", - "enum": ["Tang", "Xsux", "Xpeo", "Blis", "Ugar", "Egyp", "Brai", "Egyh", "Loma", "Egyd", "Hluw", "Maya", "Sgnw", "Inds", "Mero", "Merc", "Sarb", "Narb", "Roro", "Phnx", "Lydi", "Tfng", "Samr", "Armi", "Hebr", "Palm", "Hatr", "Prti", "Phli", "Phlp", "Phlv", "Avst", "Syrc", "Syrn", "Syrj", "Syre", "Mani", "Mand", "Mong", "Nbat", "Arab", "Aran", "Nkoo", "Adlm", "Thaa", "Orkh", "Hung", "Grek", "Cari", "Lyci", "Copt", "Goth", "Ital", "Runr", "Ogam", "Latn", "Latg", "Latf", "Moon", "Osge", "Cyrl", "Cyrs", "Glag", "Elba", "Perm", "Armn", "Aghb", "Geor", "Geok", "Dupl", "Dsrt", "Bass", "Osma", "Olck", "Wara", "Pauc", "Mroo", "Medf", "Visp", "Shaw", "Plrd", "Jamo", "Bopo", "Hang", "Kore", "Kits", "Teng", "Cirt", "Sara", "Piqd", "Brah", "Sidd", "Khar", "Guru", "Gong", "Gonm", "Mahj", "Deva", "Sylo", "Kthi", "Sind", "Shrd", "Gujr", "Takr", "Khoj", "Mult", "Modi", "Beng", "Tirh", "Orya", "Dogr", "Soyo", "Tibt", "Phag", "Marc", "Newa", "Bhks", "Lepc", "Limb", "Mtei", "Ahom", "Zanb", "Telu", "Gran", "Saur", "Knda", "Taml", "Mlym", "Sinh", "Cakm", "Mymr", "Lana", "Thai", "Tale", "Talu", "Khmr", "Laoo", "Kali", "Cham", "Tavt", "Bali", "Java", "Sund", "Rjng", "Leke", "Batk", "Maka", "Bugi", "Tglg", "Hano", "Buhd", "Tagb", "Qaaa", "Sora", "Lisu", "Lina", "Linb", "Cprt", "Hira", "Kana", "Hrkt", "Jpan", "Nkgb", "Ethi", "Bamu", "Kpel", "Qabx", "Mend", "Afak", "Cans", "Cher", "Hmng", "Yiii", "Vaii", "Wole", "Zsye", "Zinh", "Zmth", "Zsym", "Zxxx", "Zyyy", "Zzzz", "Nshu", "Hani", "Hans", "Hant", "Hanb", "Kitl", "Jurc"] - } - }, - "name": { - "$id": "#/properties/name", - "type": "string", - "title": "Filename of the model", - "pattern": "^(.*)$" - }, - "graphemes": { - "$id": "#/properties/graphemes", - "type": "array", - "title": "Code points recognizable by the model", - "uniqueItems": true, - "minItems": 1, - "items": { - "$id": "#/properties/graphemes/items", - "type": "string", - "pattern": "^(.*)$" - } - } - } -} From c4b26b6c820548dc6cba74c590ad8f4c0635b484 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Sun, 5 Jan 2025 22:35:53 +0100 Subject: [PATCH 5/9] Integrate new model data dir in cli driver --- environment.yml | 1 + environment_cuda.yml | 1 + kraken/kraken.py | 64 +++++++++++++++++++++----------------------- setup.cfg | 1 + 4 files changed, 33 insertions(+), 34 deletions(-) diff --git a/environment.yml b/environment.yml index c624a8a21..6e4ca53de 100644 --- a/environment.yml +++ b/environment.yml @@ -33,4 +33,5 @@ dependencies: - pip: - coremltools~=8.1 - htrmopo + - platformdirs - file:. diff --git a/environment_cuda.yml b/environment_cuda.yml index 243a9b7c5..d001181c8 100644 --- a/environment_cuda.yml +++ b/environment_cuda.yml @@ -34,4 +34,5 @@ dependencies: - pip: - coremltools~=8.1 - htrmopo + - platformdirs - file:. diff --git a/kraken/kraken.py b/kraken/kraken.py index ed01e9a86..065228c3a 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -18,28 +18,28 @@ Command line drivers for recognition functionality. """ -import dataclasses -import logging import os import uuid +import click import shlex +import logging import warnings -from functools import partial -from pathlib import Path -from typing import IO, Any, Callable, Dict, List, Union, cast +import dataclasses -import click from PIL import Image +from pathlib import Path +from itertools import chain +from functools import partial from importlib import resources +from platformdirs import user_data_dir +from typing import IO, Any, Callable, Dict, List, Union, cast from rich import print from rich.tree import Tree from rich.table import Table from rich.console import Group from rich.traceback import install -from rich.logging import RichHandler from rich.markdown import Markdown -from rich.progress import Progress from kraken.lib import log @@ -107,7 +107,7 @@ def binarizer(threshold, zoom, escale, border, perc, range, low, high, input, ou processing_steps=ctx.meta['steps'])) else: form = None - ext = os.path.splitext(output)[1] + ext = Path(output).suffix if ext in ['.jpg', '.jpeg', '.JPG', '.JPEG', '']: form = 'png' if ext: @@ -359,7 +359,6 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty placing their respective outputs in temporary files. """ import glob - import os.path import tempfile from threadpoolctl import threadpool_limits @@ -373,9 +372,8 @@ def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_ty # expand batch inputs if batch_input and suffix: for batch_expr in batch_input: - for in_file in glob.glob(os.path.expanduser(batch_expr), recursive=True): - - input.append((in_file, '{}{}'.format(os.path.splitext(in_file)[0], suffix))) + for in_file in glob.glob(str(Path(batch_expr).expanduser()), recursive=True): + input.append(Path(in_file).with_suffix(suffix)) # parse pdfs if format_type == 'pdf': @@ -515,13 +513,14 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, logger.warning(f'Baseline model ({model}) given but legacy segmenter selected. Forcing to -bl.') boxes = False + model = [Path(m) for m in model] if boxes is False: if not model: model = [SEGMENTATION_DEFAULT_MODEL] ctx.meta['steps'].append(ProcessingStep(id=str(uuid.uuid4()), category='processing', description='Baseline and region segmentation', - settings={'model': [os.path.basename(m) for m in model], + settings={'model': [m.name for m in model], 'text_direction': text_direction})) # first try to find the segmentation models by their given names, then @@ -529,15 +528,16 @@ def segment(ctx, model, boxes, text_direction, scale, maxcolseps, locations = [] for m in model: location = None - search = [m, os.path.join(click.get_app_dir(APP_NAME), m)] + search = chain([m], + Path(user_data_dir('htrmopo')).rglob(str(m)), + Path(click.get_app_dir('kraken')).rglob(str(m))) for loc in search: - if os.path.isfile(loc): + if loc.is_file(): location = loc locations.append(loc) break if not location: - raise click.BadParameter(f'No model for {m} found') - + raise click.BadParameter(f'No model for {str(m)} found') from kraken.lib.vgsl import TorchVGSLModel model = [] @@ -638,11 +638,12 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): nm: Dict[str, models.TorchSeqRecognizer] = {} ign_tags = model.pop('ignore') for k, v in model.items(): - search = [v, - os.path.join(click.get_app_dir(APP_NAME), v)] + search = chain([Path(v)], + Path(user_data_dir('htrmopo')).rglob(v), + Path(click.get_app_dir('kraken')).rglob(v)) location = None for loc in search: - if os.path.isfile(loc): + if loc.is_file(): location = loc break if not location: @@ -669,7 +670,7 @@ def ocr(ctx, model, pad, reorder, base_dir, no_segmentation, text_direction): category='processing', description='Text line recognition', settings={'text_direction': text_direction, - 'models': ' '.join(os.path.basename(v) for v in model.values()), + 'models': ' '.join(Path(v).name for v in model.values()), 'pad': pad, 'bidi_reordering': reorder})) @@ -807,29 +808,24 @@ def get(ctx, model_id): """ Retrieves a model from the repository. """ - import glob - - from htrmopo import get_model, get_description + from htrmopo import get_model + from kraken.repo import get_description from kraken.lib.progress import KrakenDownloadProgressBar try: - desc = get_description(model_id) + get_description(model_id, + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) except ValueError as e: logger.error(e) ctx.exit(1) - print(desc) - if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: - logger.error('Record exists but is not a kraken-compatible model') - ctx.exit(1) - with KrakenDownloadProgressBar() as progress: download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False) model_dir = get_model(model_id, - lambda total, advance: progress.update(download_task, total=total, advance=advance)) - model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iter_dir())) - message(f'Model dir: {model_dir} (model files: {model_candidates})') + callback=lambda total, advance: progress.update(download_task, total=total, advance=advance)) + model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iterdir())) + message(f'Model dir: {model_dir} (model files: {", ".join(x.name for x in model_candidates)})') if __name__ == '__main__': diff --git a/setup.cfg b/setup.cfg index efc7000ba..583875a14 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,7 @@ install_requires = lightning~=2.4.0 torchmetrics>=1.1.0 threadpoolctl~=3.5.0 + platformdirs rich [options.extras_require] From de4d3e777260b11c95c5688bf45da88828f24a9e Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 13 Jan 2025 12:16:02 +0100 Subject: [PATCH 6/9] Proper printing of metric-less models in repo --- kraken/ketos/repo.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/kraken/ketos/repo.py b/kraken/ketos/repo.py index 00e48eed8..0eb529369 100644 --- a/kraken/ketos/repo.py +++ b/kraken/ketos/repo.py @@ -159,21 +159,23 @@ def publish(ctx, metadata, access_token, doi, private, model): if len(base_model := _get_field_list('base model URL')): frontmatter['base_model'] = base_model - # take last metrics field, falling back to accuracy field in model metadata - metrics = {} - if nn.user_metadata.get('metrics', None) is not None: - if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None: - metrics['cer'] = 100 - (val_accuracy * 100) - if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None: - metrics['wer'] = 100 - (val_word_accuracy * 100) - elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None: - metrics['cer'] = 100 - accuracy - frontmatter['metrics'] = metrics software_hints = ['kind=vgsl'] - # some recognition-specific software hints + # take last metrics field, falling back to accuracy field in model metadata if nn.model_type == 'recognition': + metrics = {} + if len(nn.user_metadata.get('metrics', '')): + if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None: + metrics['cer'] = 100 - (val_accuracy * 100) + if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None: + metrics['wer'] = 100 - (val_word_accuracy * 100) + elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None: + metrics['cer'] = 100 - accuracy + frontmatter['metrics'] = metrics + + # some recognition-specific software hints and metrics software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}']) + frontmatter['software_hints'] = software_hints frontmatter['software_name'] = 'kraken' From 889f8bbfcb95bfa5b581dde10527ef249298bc65 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 13 Jan 2025 12:16:32 +0100 Subject: [PATCH 7/9] more robust model desc display in kraken --- kraken/kraken.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/kraken/kraken.py b/kraken/kraken.py index 065228c3a..d1db25eaa 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -711,7 +711,9 @@ def _render_creators(creators): return o def _render_metrics(metrics): - return [f'{k}: {v:.2f}' for k, v in metrics.items()] + if metrics: + return [f'{k}: {v:.2f}' for k, v in metrics.items()] + return '' if metadata_version == 'highest': metadata_version = None @@ -757,12 +759,12 @@ def _render_metrics(metrics): table.add_row('model type', Group(*desc.model_type)) table.add_row('language', Group(*[iso639_3_to_name(x) for x in desc.language])) table.add_row('script', Group(*[iso15924_to_name(x) for x in desc.script])) - table.add_row('keywords', Group(*desc.keywords)) - table.add_row('datasets', Group(*desc.datasets)) - table.add_row('metrics', Group(*_render_metrics(desc.metrics))) - table.add_row('base model', Group(*desc.base_model)) + table.add_row('keywords', Group(*desc.keywords) if desc.keywords else '') + table.add_row('datasets', Group(*desc.datasets) if desc.datasets else '') + table.add_row('metrics', Group(*_render_metrics(desc.metrics)) if desc.metrics else '') + table.add_row('base model', Group(*desc.base_model) if desc.base_model else '') table.add_row('software', desc.software_name) - table.add_row('software_hints', Group(*desc.software_hints)) + table.add_row('software_hints', Group(*desc.software_hints) if desc.software_hints else '') table.add_row('license', desc.license) table.add_row('creators', Group(*_render_creators(desc.creators))) table.add_row('description', Markdown(desc.description)) From 9b0831342940dcfcae6c4bd627abc308314bfe89 Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 13 Jan 2025 12:18:08 +0100 Subject: [PATCH 8/9] tests for new repository wrappers --- tests/test_repo.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/tests/test_repo.py b/tests/test_repo.py index 3b54f099a..f7d3ff29e 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -11,7 +11,7 @@ class TestRepo(unittest.TestCase): """ - Testing interaction with the model repository. + Testing our wrappers around HTRMoPo """ def setUp(self): @@ -33,30 +33,11 @@ def test_get_description(self): Tests fetching the description of a model. """ record = repo.get_description('10.5281/zenodo.8425684') - self.assertEqual(record['doi'], '10.5281/zenodo.8425684') - - def test_get_model(self): - """ - Tests fetching a model. - """ - id = repo.get_model('10.5281/zenodo.8425684', - path=self.temp_model.name) - self.assertEqual(id, 'omnisyr_best.mlmodel') - self.assertEqual((self.temp_path / id).stat().st_size, 16245671) + self.assertEqual(record.doi, '10.5281/zenodo.8425684') def test_prev_record_version_get_description(self): """ Tests fetching the description of a model that has a superseding newer version. """ record = repo.get_description('10.5281/zenodo.6657809') - self.assertEqual(record['doi'], '10.5281/zenodo.6657809') - - def test_prev_record_version_get_model(self): - """ - Tests fetching a model that has a superseding newer version. - """ - id = repo.get_model('10.5281/zenodo.6657809', - path=self.temp_model.name) - self.assertEqual(id, 'HTR-United-Manu_McFrench.mlmodel') - self.assertEqual((self.temp_path / id).stat().st_size, 16176844) - + self.assertEqual(record.doi, '10.5281/zenodo.6657809') From ded52d126c400e04f53f554d8f66a4d92f0c1b4f Mon Sep 17 00:00:00 2001 From: Benjamin Kiessling Date: Mon, 13 Jan 2025 21:44:03 +0100 Subject: [PATCH 9/9] Bump up htrmopo version to 0.3 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 583875a14..0b29c0b5c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,7 +58,7 @@ install_requires = scikit-image~=0.24.0 shapely>=2.0.6,~=2.0.6 pyarrow - htrmopo + htrmopo>=0.3,~=0.3 lightning~=2.4.0 torchmetrics>=1.1.0 threadpoolctl~=3.5.0