diff --git a/kraken/ketos/recognition.py b/kraken/ketos/recognition.py
index c1fb57482..2fe0d7b98 100644
--- a/kraken/ketos/recognition.py
+++ b/kraken/ketos/recognition.py
@@ -20,12 +20,14 @@
 """
 import logging
 import pathlib
+import random
 from typing import List
 from functools import partial
 import warnings

 import click
 from threadpoolctl import threadpool_limits
+import torch

 from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC
 from kraken.lib.exceptions import KrakenInputException
@@ -390,9 +392,11 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
               help='Whether to honor fixed splits in binary datasets.')
 @click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
 @click.option('--no-legacy-polygons', show_default=True, default=False, is_flag=True, help='Force disable the legacy polygon extractor.')
+@click.option('--sample-percentage', show_default=True, type=click.IntRange(1, 100), default=100,
+              help='Percentage of the test dataset to use for evaluation.')
 def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
          threads, reorder, base_dir, normalization, normalize_whitespace,
-         force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons):
+         force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons, sample_percentage):
     """
     Evaluate on a test set.
     """
@@ -481,6 +485,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,

     cer_list = []
     wer_list = []
+    cer_case_insensitive_list = []

     with threadpool_limits(limits=threads):
         for p, net in nn.items():
@@ -511,6 +516,15 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
                 # don't encode validation set as the alphabets may not match causing encoding failures
                 ds.no_encode()
+
+                # Randomly sample a percentage of the dataset
+                if sample_percentage < 100:
+                    dataset_indices = list(range(len(ds)))
+                    sample_size = int(len(ds) * sample_percentage / 100)
+                    sampled_indices = random.sample(dataset_indices, sample_size)
+                    ds = torch.utils.data.Subset(ds, sampled_indices)
+                    logger.info(f'Testing on a random {sample_percentage}% of the dataset ({sample_size} lines).')
+
                 ds_loader = DataLoader(ds,
                                        batch_size=batch_size,
                                        num_workers=workers,
@@ -518,6 +532,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
                                        collate_fn=collate_sequences)

                 test_cer = CharErrorRate()
+                test_cer_case_insensitive = CharErrorRate()
                 test_wer = WordErrorRate()

                 with KrakenProgressBar() as progress:
@@ -537,6 +552,8 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
                             algn_pred.extend(algn2)
                             error += c
                             test_cer.update(x, y)
+                            # update the case-insensitive CER metric
+                            test_cer_case_insensitive.update(x.lower(), y.lower())
                             test_wer.update(x, y)

                         except FileNotFoundError as e:
@@ -550,12 +567,14 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
                             progress.update(pred_task, advance=1)

             cer_list.append(1.0 - test_cer.compute())
+            cer_case_insensitive_list.append(1.0 - test_cer_case_insensitive.compute())
             wer_list.append(1.0 - test_wer.compute())
             confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
             rep = render_report(p,
                                 chars,
                                 error,
                                 cer_list[-1],
+                                cer_case_insensitive_list[-1],
                                 wer_list[-1],
                                 confusions,
                                 scripts,
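The two evaluation changes above can be exercised in isolation: random subsampling of a dataset with torch.utils.data.Subset, and a case-insensitive CER obtained by lowercasing both prediction and ground truth before the metric update. A minimal sketch, assuming torch and torchmetrics are installed; the dataset and the prediction/ground-truth pair are dummies, not kraken output:

    import random

    import torch
    from torch.utils.data import Subset, TensorDataset
    from torchmetrics.text import CharErrorRate

    ds = TensorDataset(torch.arange(100))  # stand-in for the test dataset
    sample_percentage = 25                 # as passed via --sample-percentage
    sample_size = int(len(ds) * sample_percentage / 100)
    ds = Subset(ds, random.sample(range(len(ds)), sample_size))
    print(f'evaluating on {len(ds)} of 100 lines')

    cer, cer_ci = CharErrorRate(), CharErrorRate()
    pred, gt = 'Foo Bar', 'foo bar'        # dummy prediction/ground-truth pair
    cer.update(pred, gt)
    cer_ci.update(pred.lower(), gt.lower())
    print(f'CER {cer.compute().item():.2f}, '
          f'case-insensitive CER {cer_ci.compute().item():.2f}')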
diff --git a/kraken/lib/dataset/recognition.py b/kraken/lib/dataset/recognition.py
index 4a1d6e64d..2f5693140 100644
--- a/kraken/lib/dataset/recognition.py
+++ b/kraken/lib/dataset/recognition.py
@@ -24,6 +24,9 @@
 import dataclasses
 import multiprocessing as mp
+import os
+from torchvision.utils import save_image
+
 from collections import Counter
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, List, Literal, Optional,
@@ -59,9 +62,9 @@ def __init__(self):
         import cv2
         cv2.setNumThreads(0)
         from albumentations import (Blur, Compose, ElasticTransform,
-                                    MedianBlur, MotionBlur, OneOf, Affine,
+                                    MedianBlur, MotionBlur, OneOf, SafeRotate,
                                     OpticalDistortion, PixelDropout,
-                                    ShiftScaleRotate, ToFloat)
+                                    ToFloat)

         self._transforms = Compose([
             ToFloat(),
@@ -71,16 +74,45 @@ def __init__(self):
                 MedianBlur(blur_limit=3, p=0.1),
                 Blur(blur_limit=3, p=0.1),
             ], p=0.2),
-            ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=1, p=0.2),
             OneOf([
                 OpticalDistortion(p=0.3),
-                ElasticTransform(alpha=64, sigma=25, p=0.1),
-                Affine(translate_px=(0, 5), rotate=(-3, 3), shear=(-5, 5), p=0.2)
+                ElasticTransform(alpha=7, sigma=25, p=0.1),
+                SafeRotate(limit=(-3, 3), border_mode=cv2.BORDER_CONSTANT, p=0.2)
             ], p=0.2),
         ], p=0.5)

-    def __call__(self, image):
-        return self._transforms(image=image)
+    def __call__(self, image, index):
+        im = image.permute((1, 2, 0)).numpy()
+        o = self._transforms(image=im)
+        im = torch.tensor(o['image'].transpose(2, 0, 1))
+
+        # Optionally save augmented images to disk for debugging or
+        # inspection: creates an 'augmented_images' directory (if not
+        # already existing), saves the image as 'image_{index}.png', and
+        # logs the save path.
+        # TODO: expose the toggle and the save folder as user options and
+        # add error handling.
+        save_debug_images = False
+        if save_debug_images:
+            output_dir = 'augmented_images'
+            os.makedirs(output_dir, exist_ok=True)
+            save_path = os.path.join(output_dir, f'image_{index}.png')
+            # save the (C, H, W) float tensor with torchvision's save_image
+            save_image(im, save_path)
+            logger.info(f'Saved augmented image to {save_path}')
+
+        return im


 class ArrowIPCRecognitionDataset(Dataset):
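The geometry branch above swaps ShiftScaleRotate and Affine for SafeRotate, which rotates the whole line image without clipping its corners, and lowers ElasticTransform's alpha from 64 to 7 for gentler distortions. A minimal sketch of the reworked branch in isolation, assuming albumentations and opencv-python are installed; the input line image is a random dummy:

    import cv2
    import numpy as np
    from albumentations import (Compose, ElasticTransform, OneOf,
                                OpticalDistortion, SafeRotate, ToFloat)

    aug = Compose([
        ToFloat(),
        OneOf([
            OpticalDistortion(p=0.3),
            ElasticTransform(alpha=7, sigma=25, p=0.1),
            SafeRotate(limit=(-3, 3), border_mode=cv2.BORDER_CONSTANT, p=0.2),
        ], p=0.2),
    ], p=0.5)

    line = np.random.randint(0, 256, (48, 400, 1), dtype=np.uint8)  # dummy H x W x C line
    out = aug(image=line)['image']
    print(out.shape, out.dtype)  # float32 whenever the Compose fires (p=0.5)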
+ """ + isSave = False + if isSave: + # Save augmented image using torchvision's save_image + output_dir = "augmented_images" + os.makedirs(output_dir, exist_ok=True) + + save_path = os.path.join( + output_dir, + f"image_{index}.png" + ) + save_image(im, save_path) + logger.info(f"Saved augmented image to {save_path}") + + return im class ArrowIPCRecognitionDataset(Dataset): @@ -268,9 +300,7 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: im = Image.open(io.BytesIO(sample['im'])) im = self.transforms(im) if self.aug: - im = im.permute((1, 2, 0)).numpy() - o = self.aug(image=im) - im = torch.tensor(o['image'].transpose(2, 0, 1)) + im = self.aug(image=im, index=index) text = self._apply_text_transform(sample) except Exception: self.failed_samples.add(index) @@ -452,9 +482,8 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}') self._im_mode.value = im_mode if self.aug: - im = im.permute((1, 2, 0)).numpy() - o = self.aug(image=im) - im = torch.tensor(o['image'].transpose(2, 0, 1)) + im = self.aug(image=im, index=index) + return {'image': im, 'target': item[1]} except Exception: self.failed_samples.add(index) @@ -637,9 +666,7 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}') self._im_mode.value = im_mode if self.aug: - im = im.permute((1, 2, 0)).numpy() - o = self.aug(image=im) - im = torch.tensor(o['image'].transpose(2, 0, 1)) + im = self.aug(image=im, index=index) return {'image': im, 'target': item[1]} except Exception: raise diff --git a/kraken/serialization.py b/kraken/serialization.py index 9666e344c..e5daa6444 100644 --- a/kraken/serialization.py +++ b/kraken/serialization.py @@ -248,6 +248,7 @@ def render_report(model: str, chars: int, errors: int, char_accuracy: float, + char_CI_accucary: float, #Case insensitive word_accuracy: float, char_confusions: 'Counter', scripts: 'Counter', @@ -278,6 +279,7 @@ def render_report(model: str, 'chars': chars, 'errors': errors, 'character_accuracy': char_accuracy * 100, + 'character_CI_accucary': char_CI_accucary * 100, 'word_accuracy': word_accuracy * 100, 'insertions': sum(insertions.values()), 'deletions': deletions, diff --git a/kraken/templates/report b/kraken/templates/report index abd81fbb2..09ee14cda 100644 --- a/kraken/templates/report +++ b/kraken/templates/report @@ -3,6 +3,7 @@ {{ report.chars }} Characters {{ report.errors }} Errors {{ '%0.2f'| format(report.character_accuracy) }}% Character Accuracy +{{ '%0.2f'| format(report.character_CI_accucary) }}% Character Accuracy (Case-insensitive) {{ '%0.2f'| format(report.word_accuracy) }}% Word Accuracy {{ report.insertions }} Insertions