Add in reading order benchmark
VikParuchuri committed Apr 19, 2024
1 parent 4a6a6e3 commit 26d9952
Showing 3 changed files with 104 additions and 4 deletions.
79 changes: 79 additions & 0 deletions benchmark/ordering.py
@@ -0,0 +1,79 @@
import argparse
import collections
import json
import os
import time

import datasets

from surya.benchmark.metrics import rank_accuracy
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
from surya.settings import settings

def main():
    parser = argparse.ArgumentParser(description="Benchmark surya reading order model.")
    parser.add_argument("--results_dir", type=str, help="Path to directory to save benchmark results to.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    pathname = "order_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if args.max is not None:
        split = f"train[:{args.max}]"
    dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = [i.convert("RGB") for i in images]
    bboxes = list(dataset["bboxes"])

    start = time.time()
    order_predictions = batch_ordering(images, bboxes, model, processor)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    page_metrics = collections.OrderedDict()
    mean_accuracy = 0
    for idx, order_pred in enumerate(order_predictions):
        row = dataset[idx]
        pred_labels = [str(l.position) for l in order_pred.bboxes]
        labels = row["labels"]
        accuracy = rank_accuracy(pred_labels, labels)
        mean_accuracy += accuracy
        page_results = {
            "accuracy": accuracy,
            "box_count": len(labels)
        }

        page_metrics[idx] = page_results

    mean_accuracy /= len(order_predictions)

    out_data = {
        "time": surya_time,
        "mean_accuracy": mean_accuracy,
        "page_metrics": page_metrics
    }

    with open(os.path.join(result_path, "results.json"), "w+") as f:
        json.dump(out_data, f, indent=4)

    print(f"Mean accuracy is {mean_accuracy:.2f}.")
    print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
    print("Mean accuracy is the % of correct ranking pairs.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
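
For reference, a typical invocation of the new benchmark might look like the following; this is illustrative, assuming surya is installed, the order_bench dataset is reachable, and settings.RESULT_DIR points at results/:

python benchmark/ordering.py --max 100
# prints mean accuracy and timing, and writes results/benchmark/order_bench/results.json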
21 changes: 20 additions & 1 deletion surya/benchmark/metrics.py
@@ -117,4 +117,23 @@ def mean_coverage(preds, references):
    if len(coverages) == 0:
        return 0
    coverage = sum(coverages) / len(coverages)
    return {"coverage": coverage}


def rank_accuracy(preds, references):
    # Preds and references need to be aligned so each position refers to the same bbox
    pairs = []
    for i, pred in enumerate(preds):
        for j, pred2 in enumerate(preds):
            if i == j:
                continue
            pairs.append((i, j, pred > pred2))

    # Find how many of the prediction rankings are correct
    correct = 0
    for i, ref in enumerate(references):
        for j, ref2 in enumerate(references):
            if (i, j, ref > ref2) in pairs:
                correct += 1

    return correct / len(pairs)
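
As a quick sanity check on the metric, a minimal toy example (values are illustrative, not from the benchmark); note the positions are compared as strings, which matches numeric order for single digits:

preds = ["0", "2", "1"]  # predicted positions, stringified as in the benchmark script
refs = ["0", "1", "2"]   # ground-truth positions
# Of the 6 ordered pairs, only the (1, 2) and (2, 1) pairs disagree,
# so rank_accuracy(preds, refs) returns 4 / 6 ≈ 0.67.
print(rank_accuracy(preds, refs))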
8 changes: 5 additions & 3 deletions surya/ordering.py
@@ -12,9 +12,9 @@
def get_batch_size():
    batch_size = settings.ORDER_BATCH_SIZE
    if batch_size is None:
-        batch_size = 4
+        batch_size = 8
    if settings.TORCH_DEVICE_MODEL == "mps":
-        batch_size = 4
+        batch_size = 8
    if settings.TORCH_DEVICE_MODEL == "cuda":
        batch_size = 32
    return batch_size
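
These defaults only apply when no batch size is configured explicitly; assuming the usual pydantic-settings environment override used by surya.settings, the value can also be pinned when running the benchmark:

ORDER_BATCH_SIZE=16 python benchmark/ordering.py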
@@ -91,9 +91,11 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor
                last_token_mask.append([0])
                batch_predictions[j].append(pred)
                done[j] = True
-            else:
+            elif len(batch_predictions[j]) < label_count - 1:
                last_token_mask.append([1])
                batch_predictions[j].append(pred)  # Get rank prediction for given position
+            else:
+                last_token_mask.append([0])

        # Break when we finished generating all sequences
        if all(done):
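The switch from a bare else to the elif guard caps how many rank predictions are collected per image. A minimal sketch of the same pattern, with illustrative names (append_pred is not part of surya):

# Illustrative sketch only: keep at most label_count - 1 further predictions,
# masking out (rather than appending) tokens generated past that point.
def append_pred(batch_predictions, last_token_mask, j, pred, label_count):
    if len(batch_predictions[j]) < label_count - 1:
        last_token_mask.append([1])
        batch_predictions[j].append(pred)
    else:
        last_token_mask.append([0])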
