Integration of new model repository #672

Open
wants to merge 9 commits into main
Changes from all commits
3 changes: 0 additions & 3 deletions .github/workflows/test.yml
@@ -52,9 +52,6 @@ jobs:
         python -m build --sdist --wheel --outdir dist/ .
       - name: Publish a Python distribution to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_API_TOKEN }}
       - name: Upload PyPI artifacts to GH storage
         uses: actions/upload-artifact@v3
         with:
2 changes: 2 additions & 0 deletions environment.yml
@@ -32,4 +32,6 @@ dependencies:
   - setuptools>=36.6.0,<70.0.0
   - pip:
     - coremltools~=8.1
+    - htrmopo
+    - platformdirs
     - file:.
2 changes: 2 additions & 0 deletions environment_cuda.yml
@@ -33,4 +33,6 @@ dependencies:
   - setuptools>=36.6.0,<70.0.0
   - pip:
     - coremltools~=8.1
+    - htrmopo
+    - platformdirs
     - file:.
227 changes: 165 additions & 62 deletions kraken/ketos/repo.py
@@ -18,94 +18,197 @@

Command line driver for publishing models to the model repository.
"""
import re
import logging
import os

import click

from pathlib import Path
from difflib import get_close_matches

from .util import message

logging.captureWarnings(True)
logger = logging.getLogger('kraken')


def _validate_script(script: str) -> str:
from htrmopo.util import _iso15924
if script not in _iso15924:
return get_close_matches(script, _iso15924.keys())
return script


def _validate_language(language: str) -> str:
from htrmopo.util import _iso639_3
if language not in _iso639_3:
return get_close_matches(language, _iso639_3.keys())
return language


def _validate_license(license: str) -> str:
from htrmopo.util import _licenses
if license not in _licenses:
return get_close_matches(license, _licenses.keys())
return license


def _get_field_list(name,
validation_fn=lambda x: x,
required: bool = False):
values = []
while True:
value = click.prompt(name, default='')
if value:
if (cand := validation_fn(value)) == value:
values.append(value)
else:
message(f'Not a valid {name} value. Did you mean {cand}?')
else:
if click.confirm(f'All `{name}` values added?'):
if required and not values:
message(f'`{name}` is a required field.')
continue
else:
break
else:
continue
return values


@click.command('publish')
@click.pass_context
@click.option('-i', '--metadata', show_default=True,
type=click.File(mode='r', lazy=True), help='Metadata for the '
'model. Will be prompted from the user if not given')
type=click.File(mode='r', lazy=True), help='Model card file for the model.')
@click.option('-a', '--access-token', prompt=True, help='Zenodo access token')
@click.option('-d', '--doi', help='DOI of an existing record to update')
@click.option('-p', '--private/--public', default=False, help='Disables Zenodo '
'community inclusion request. Allows upload of models that will not show '
'up on `kraken list` output')
@click.argument('model', nargs=1, type=click.Path(exists=False, readable=True, dir_okay=False))
def publish(ctx, metadata, access_token, private, model):
def publish(ctx, metadata, access_token, doi, private, model):
"""
Publishes a model on the zenodo model repository.
"""
import json
import yaml
import tempfile

from importlib import resources
from jsonschema import validate
from jsonschema.exceptions import ValidationError
from htrmopo import publish_model, update_model

from kraken import repo
from kraken.lib import models
from kraken.lib.vgsl import TorchVGSLModel
from kraken.lib.progress import KrakenDownloadProgressBar

ref = resources.files('kraken').joinpath('metadata.schema.json')
with open(ref, 'rb') as fp:
schema = json.load(fp)

nn = models.load_any(model)

if not metadata:
author = click.prompt('author')
affiliation = click.prompt('affiliation')
summary = click.prompt('summary')
description = click.edit('Write long form description (training data, transcription standards) of the model here')
accuracy_default = None
# take last accuracy measurement in model metadata
if 'accuracy' in nn.nn.user_metadata and nn.nn.user_metadata['accuracy']:
accuracy_default = nn.nn.user_metadata['accuracy'][-1][1] * 100
accuracy = click.prompt('accuracy on test set', type=float, default=accuracy_default)
script = [
click.prompt(
'script',
type=click.Choice(
sorted(
schema['properties']['script']['items']['enum'])),
show_choices=True)]
license = click.prompt(
'license',
type=click.Choice(
sorted(
schema['properties']['license']['enum'])),
show_choices=True)
metadata = {
'authors': [{'name': author, 'affiliation': affiliation}],
'summary': summary,
'description': description,
'accuracy': accuracy,
'license': license,
'script': script,
'name': os.path.basename(model),
'graphemes': ['a']
}
while True:
try:
validate(metadata, schema)
except ValidationError as e:
message(e.message)
metadata[e.path[-1]] = click.prompt(e.path[-1], type=float if e.schema['type'] == 'number' else str)
continue
break
pub_fn = publish_model

_yaml_delim = r'(?:---|\+\+\+)'
_yaml = r'(.*?)'
_content = r'\s*(.+)$'
_re_pattern = r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content
_yaml_regex = re.compile(_re_pattern, re.S | re.M)

nn = TorchVGSLModel.load_model(model)

frontmatter = {}
# construct metadata if none is given
if metadata:
frontmatter, content = _yaml_regex.match(metadata.read()).groups()
frontmatter = yaml.safe_load(frontmatter)
else:
metadata = json.load(metadata)
validate(metadata, schema)
metadata['graphemes'] = [char for char in ''.join(nn.codec.c2l.keys())]
with KrakenDownloadProgressBar() as progress:
frontmatter['summary'] = click.prompt('summary')
content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here')

creators = []
message('To stop adding authors, leave the author name field empty.')
while True:
author = click.prompt('author name', default='')
if author:
creators.append({'name': author})
else:
if click.confirm('All authors added?'):
break
else:
continue
affiliation = click.prompt('affiliation', default='')
orcid = click.prompt('orcid', default='')
if affiliation is not None:
creators[-1]['affiliation'] = affiliation
if orcid is not None:
creators[-1]['orcid'] = orcid
if not creators:
raise click.UsageError('The `authors` field is obligatory. Aborting')

frontmatter['authors'] = creators
while True:
license = click.prompt('license')
if (lic := _validate_license(license)) == license:
frontmatter['license'] = license
break
else:
message(f'Not a valid license identifier. Did you mean {lic}?')

message('To stop adding values to the following fields, enter an empty field.')

frontmatter['language'] = _get_field_list('language', _validate_language, required=True)
frontmatter['script'] = _get_field_list('script', _validate_script, required=True)

if len(tags := _get_field_list('tag')):
frontmatter['tags'] = tags + ['kraken_pytorch']
if len(datasets := _get_field_list('dataset URL')):
frontmatter['datasets'] = datasets
if len(base_model := _get_field_list('base model URL')):
frontmatter['base_model'] = base_model

software_hints = ['kind=vgsl']

# take last metrics field, falling back to accuracy field in model metadata
if nn.model_type == 'recognition':
metrics = {}
if len(nn.user_metadata.get('metrics', '')):
if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None:
metrics['cer'] = 100 - (val_accuracy * 100)
if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None:
metrics['wer'] = 100 - (val_word_accuracy * 100)
elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None:
metrics['cer'] = 100 - accuracy
frontmatter['metrics'] = metrics

# some recognition-specific software hints and metrics
software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}'])

frontmatter['software_hints'] = software_hints

frontmatter['software_name'] = 'kraken'
frontmatter['model_type'] = [nn.model_type]

# build temporary directory
with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress:
upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False)
oid = repo.publish_model(model, metadata, access_token, lambda total, advance: progress.update(upload_task, total=total, advance=advance), private)
message('model PID: {}'.format(oid))

model = Path(model).resolve()
tmpdir = Path(tmpdir)
(tmpdir / model.name).resolve().symlink_to(model)
if nn.model_type == 'recognition':
# v0 metadata only supports recognition models
v0_metadata = {
'summary': frontmatter['summary'],
'description': content,
'license': frontmatter['license'],
'script': frontmatter['script'],
'name': model.name,
'graphemes': [char for char in ''.join(nn.codec.c2l.keys())]
}
if frontmatter['metrics']:
v0_metadata['accuracy'] = 100 - metrics['cer']
with open(tmpdir / 'metadata.json', 'w') as fo:
json.dump(v0_metadata, fo)
kwargs = {'model': tmpdir,
'model_card': f'---\n{yaml.dump(frontmatter)}---\n{content}',
'access_token': access_token,
'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance),
'private': private}
if doi:
pub_fn = update_model
kwargs['model_id'] = doi
oid = pub_fn(**kwargs)
message(f'model PID: {oid}')
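For orientation, the model card consumed by the reworked `-i/--metadata` option is a Markdown file with a YAML frontmatter block, which `publish()` splits with the `_yaml_regex` pattern shown above. The sketch below is illustrative only: the key names mirror the frontmatter keys assembled in `publish()` (summary, authors, license, language, script, tags, metrics, software_name, model_type), while all concrete values are made-up placeholders, not taken from this PR.

```python
# Illustrative sketch, not part of the PR: parses a minimal model card with the
# same frontmatter regex used in kraken/ketos/repo.py above. All field values
# are placeholders; only the key names follow the frontmatter built by publish().
import re
import yaml

_yaml_delim = r'(?:---|\+\+\+)'
_yaml = r'(.*?)'
_content = r'\s*(.+)$'
_yaml_regex = re.compile(r'^\s*' + _yaml_delim + _yaml + _yaml_delim + _content,
                         re.S | re.M)

model_card = """---
summary: Example print transcription model
authors:
  - name: Jane Doe
    affiliation: Example University
license: Apache-2.0
language: [deu]
script: [Latn]
tags: [kraken_pytorch]
metrics:
  cer: 4.2
software_name: kraken
model_type: [recognition]
---
Long form description of the model (training data, transcription standards) in
Markdown format.
"""

# Split the card into its YAML frontmatter and the Markdown body, as publish() does.
frontmatter, content = _yaml_regex.match(model_card).groups()
print(yaml.safe_load(frontmatter)['summary'])   # Example print transcription model
print(content.splitlines()[0])                  # Long form description of the model ...
```

Such a file could then be passed to the command as, e.g., `ketos publish -i model_card.md -a <token> model.mlmodel`, or combined with `-d <DOI>` to update an existing record; the file names, token, license value, and metric numbers here are assumptions for illustration.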