Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augmentation issues #673

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion kraken/ketos/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
"""
import logging
import pathlib
import random
from typing import List
from functools import partial
import warnings

import click
from threadpoolctl import threadpool_limits
import torch

from kraken.lib.default_specs import RECOGNITION_HYPER_PARAMS, RECOGNITION_SPEC
from kraken.lib.exceptions import KrakenInputException
Expand Down Expand Up @@ -390,9 +392,11 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
help='Whether to honor fixed splits in binary datasets.')
@click.argument('test_set', nargs=-1, callback=_expand_gt, type=click.Path(exists=False, dir_okay=False))
@click.option('--no-legacy-polygons', show_default=True, default=False, is_flag=True, help='Force disable the legacy polygon extractor.')
@click.option('--sample-percentage', show_default=True, type=click.IntRange(1, 100), default=100,
help='Percentage of the test dataset to use for evaluation.')
def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
threads, reorder, base_dir, normalization, normalize_whitespace,
force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons):
force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons, sample_percentage):
"""
Evaluate on a test set.
"""
Expand Down Expand Up @@ -481,6 +485,7 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,

cer_list = []
wer_list = []
cer_case_insensitive_list=[]

with threadpool_limits(limits=threads):
for p, net in nn.items():
Expand Down Expand Up @@ -511,13 +516,23 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,

# don't encode validation set as the alphabets may not match causing encoding failures
ds.no_encode()

# Randomly sample a percentage of the dataset
if sample_percentage < 100:
dataset_indices = list(range(len(ds)))
sample_size = int(len(ds) * sample_percentage / 100)
sampled_indices = random.sample(dataset_indices, sample_size)
ds = torch.utils.data.Subset(ds, sampled_indices)
logger.info(f'Testing on a random {sample_percentage}% of the dataset ({sample_size} lines).')

ds_loader = DataLoader(ds,
batch_size=batch_size,
num_workers=workers,
pin_memory=pin_ds_mem,
collate_fn=collate_sequences)

test_cer = CharErrorRate()
test_cer_case_insensitive = CharErrorRate()
test_wer = WordErrorRate()

with KrakenProgressBar() as progress:
Expand All @@ -537,6 +552,8 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
algn_pred.extend(algn2)
error += c
test_cer.update(x, y)
# Update case-insensitive CER metric
test_cer_case_insensitive.update(x.lower(), y.lower())
test_wer.update(x, y)

except FileNotFoundError as e:
Expand All @@ -550,12 +567,14 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
progress.update(pred_task, advance=1)

cer_list.append(1.0 - test_cer.compute())
cer_case_insensitive_list.append(1.0 - test_cer_case_insensitive.compute())
wer_list.append(1.0 - test_wer.compute())
confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
rep = render_report(p,
chars,
error,
cer_list[-1],
cer_case_insensitive_list[-1],
wer_list[-1],
confusions,
scripts,
Expand Down
59 changes: 43 additions & 16 deletions kraken/lib/dataset/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import dataclasses
import multiprocessing as mp

import os
from torchvision.utils import save_image

from collections import Counter
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, List, Literal, Optional,
Expand Down Expand Up @@ -59,9 +62,9 @@ def __init__(self):
import cv2
cv2.setNumThreads(0)
from albumentations import (Blur, Compose, ElasticTransform,
MedianBlur, MotionBlur, OneOf, Affine,
MedianBlur, MotionBlur, OneOf, SafeRotate,
OpticalDistortion, PixelDropout,
ShiftScaleRotate, ToFloat)
ToFloat)

self._transforms = Compose([
ToFloat(),
Expand All @@ -71,16 +74,45 @@ def __init__(self):
MedianBlur(blur_limit=3, p=0.1),
Blur(blur_limit=3, p=0.1),
], p=0.2),
ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=1, p=0.2),
OneOf([
OpticalDistortion(p=0.3),
ElasticTransform(alpha=64, sigma=25, p=0.1),
Affine(translate_px=(0, 5), rotate=(-3, 3), shear=(-5, 5), p=0.2)
ElasticTransform(alpha=7, sigma=25, p=0.1),
SafeRotate(limit=(-3,3), border_mode=cv2.BORDER_CONSTANT, p=0.2)
], p=0.2),
], p=0.5)

def __call__(self, image):
return self._transforms(image=image)

def __call__(self, image, index):
im = image.permute((1, 2, 0)).numpy()
o = self._transforms(image=im)
im = torch.tensor(o['image'].transpose(2, 0, 1))

"""
Saves augmented images to disk for debugging or inspection if `isSave` is set to True.

**Need improve** - User option to debug, option to set the save folder and exceptions

Parameters:
- isSave (bool): Flag to enable or disable saving images.
- index (int): Image index, used for naming the saved file.

The function creates an 'augmented_images' directory (if not already existing),
saves the image as 'image_{index}.png', and logs the save path.
"""
isSave = False
if isSave:
# Save augmented image using torchvision's save_image
output_dir = "augmented_images"
os.makedirs(output_dir, exist_ok=True)

save_path = os.path.join(
output_dir,
f"image_{index}.png"
)
save_image(im, save_path)
logger.info(f"Saved augmented image to {save_path}")

return im


class ArrowIPCRecognitionDataset(Dataset):
Expand Down Expand Up @@ -268,9 +300,7 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
im = Image.open(io.BytesIO(sample['im']))
im = self.transforms(im)
if self.aug:
im = im.permute((1, 2, 0)).numpy()
o = self.aug(image=im)
im = torch.tensor(o['image'].transpose(2, 0, 1))
im = self.aug(image=im, index=index)
text = self._apply_text_transform(sample)
except Exception:
self.failed_samples.add(index)
Expand Down Expand Up @@ -452,9 +482,8 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}')
self._im_mode.value = im_mode
if self.aug:
im = im.permute((1, 2, 0)).numpy()
o = self.aug(image=im)
im = torch.tensor(o['image'].transpose(2, 0, 1))
im = self.aug(image=im, index=index)

return {'image': im, 'target': item[1]}
except Exception:
self.failed_samples.add(index)
Expand Down Expand Up @@ -637,9 +666,7 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
logger.info(f'Upgrading "im_mode" from {self._im_mode.value} to {im_mode}')
self._im_mode.value = im_mode
if self.aug:
im = im.permute((1, 2, 0)).numpy()
o = self.aug(image=im)
im = torch.tensor(o['image'].transpose(2, 0, 1))
im = self.aug(image=im, index=index)
return {'image': im, 'target': item[1]}
except Exception:
raise
Expand Down
2 changes: 2 additions & 0 deletions kraken/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def render_report(model: str,
chars: int,
errors: int,
char_accuracy: float,
char_CI_accucary: float, #Case insensitive
word_accuracy: float,
char_confusions: 'Counter',
scripts: 'Counter',
Expand Down Expand Up @@ -278,6 +279,7 @@ def render_report(model: str,
'chars': chars,
'errors': errors,
'character_accuracy': char_accuracy * 100,
'character_CI_accucary': char_CI_accucary * 100,
'word_accuracy': word_accuracy * 100,
'insertions': sum(insertions.values()),
'deletions': deletions,
Expand Down
1 change: 1 addition & 0 deletions kraken/templates/report
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
{{ report.chars }} Characters
{{ report.errors }} Errors
{{ '%0.2f'| format(report.character_accuracy) }}% Character Accuracy
{{ '%0.2f'| format(report.character_CI_accucary) }}% Character Accuracy (Case-insensitive)
{{ '%0.2f'| format(report.word_accuracy) }}% Word Accuracy

{{ report.insertions }} Insertions
Expand Down