Finalize integration of reading order model
VikParuchuri committed Apr 22, 2024
1 parent 26d9952 commit d78a461
Showing 10 changed files with 54 additions and 157 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -36,6 +36,10 @@ jobs:
        run: |
          poetry run python benchmark/layout.py --max 5
          poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
+      - name: Run ordering benchmark test
+        run: |
+          poetry run python benchmark/ordering.py --max 5
+          poetry run python scripts/verify_benchmark_scores.py results/benchmark/ordering_bench/results.json --bench_type ordering
43 changes: 31 additions & 12 deletions README.md
@@ -35,9 +35,9 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) |
-| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | -- |
+| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) |
-| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | -- |
+| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) |

# Installation

@@ -65,11 +65,11 @@ pip install streamlit
surya_gui
```

-Pass the `--math` command line argument to use the math detection model instead of the default model. This will detect math better, but will be worse at everything else.
+Pass the `--math` command line argument to use the math text detection model instead of the default model. This will detect math better, but will be worse at everything else.

## OCR (text recognition)

-You can OCR text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
+This command will write out a json file with the detected text and bboxes:

```shell
surya_ocr DATA_PATH --images --langs hi,en
@@ -117,7 +117,7 @@ predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec

## Text line detection

-You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes.
+This command will write out a json file with the detected bboxes.

```shell
surya_detect DATA_PATH --images
@@ -162,7 +162,7 @@ predictions = batch_text_detection([image], model, processor)

## Layout analysis

-You can detect the layout of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected layout.
+This command will write out a json file with the detected layout.

```shell
surya_layout DATA_PATH --images
@@ -209,7 +209,7 @@ layout_predictions = batch_layout_detection([image], model, processor, line_pred

## Reading order

-You can detect the reading order of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected reading order and layout.
+This command will write out a json file with the detected reading order and layout.

```shell
surya_order DATA_PATH --images
@@ -224,15 +224,14 @@ The `results.json` file will contain a json dictionary where the keys are the in

- `bboxes` - detected bounding boxes for text
  - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
-  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
-  - `confidence` - the confidence of the model in the detected text (0-1). This is currently not very reliable.
-  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Text`, `Title`.
+  - `position` - the position in the reading order of the bbox, starting from 0.
+  - `label` - the label for the bbox. See the layout section of the documentation for a list of potential labels.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
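
For illustration only, a hypothetical single-page entry with this shape (field names from the list above; all values invented):

```python
# Hypothetical entry for one page in results.json (values are made up).
page_result = {
    "bboxes": [
        {"bbox": [91, 82, 532, 156], "position": 0, "label": "Section-header"},
        {"bbox": [91, 170, 532, 418], "position": 1, "label": "Text"},
    ],
    "page": 1,
    "image_bbox": [0, 0, 612, 792],
}
```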

**Performance tips**

-Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.
+Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `360MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 11GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.
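
As a rough sketch of the arithmetic above (assuming `ORDER_BATCH_SIZE` is read from the environment before surya loads its settings):

```python
import os

batch_size = 64
os.environ["ORDER_BATCH_SIZE"] = str(batch_size)

# 64 items * 360MB/item ≈ 22.5GB of VRAM
print(f"Estimated VRAM: ~{batch_size * 360 / 1024:.1f}GB")
```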

### From python
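
A minimal sketch of the post-commit API (the `load_processor` import path is an assumption based on the repo layout; `batch_ordering` no longer takes a `labels` argument, as the diff below shows):

```python
from PIL import Image

from surya.ordering import batch_ordering
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor  # assumed module path

image = Image.open("page.png")
# Layout bboxes in (x1, y1, x2, y2) format, e.g. taken from batch_layout_detection
bboxes = [[91, 82, 532, 156], [91, 170, 532, 418]]

model = load_model()
processor = load_processor()
order_predictions = batch_ordering([image], [bboxes], model, processor)
```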

@@ -357,6 +356,16 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/
- Precision - how well the predicted bboxes cover ground truth bboxes
- Recall - how well ground truth bboxes cover predicted bboxes

## Reading Order

75% mean accuracy, and 0.14 seconds per image on an A6000 GPU. See the methodology notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.

**Methodology**

I benchmarked the reading order model on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with both reading order and layout information. I wanted to avoid using a cloud service for the ground truth.

The accuracy is computed by checking whether each pair of layout boxes is in the correct relative order, then taking the percentage of pairs that are correct.
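
A toy illustration of this pairwise metric (not the benchmark code; the order dicts map box ids to reading-order positions):

```python
from itertools import combinations

def pairwise_order_accuracy(pred_order, gt_order):
    # Count box pairs whose relative order matches the ground truth
    pairs = list(combinations(gt_order.keys(), 2))
    correct = sum(
        (gt_order[a] < gt_order[b]) == (pred_order[a] < pred_order[b])
        for a, b in pairs
    )
    return correct / len(pairs)

# Prediction swaps the last two of four boxes: 5 of 6 pairs are correct
print(pairwise_order_accuracy({0: 0, 1: 1, 2: 3, 3: 2}, {0: 0, 1: 1, 2: 2, 3: 3}))  # ~0.83
```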

## Running your own benchmarks

You can benchmark the performance of surya on your machine.
@@ -403,6 +412,16 @@ python benchmark/layout.py
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Reading Order**

```
python benchmark/ordering.py
```

- `--max` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

# Training

Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
@@ -411,7 +430,7 @@ Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a m

# Commercial usage

-The text detection, layout analysis, and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
+All models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.

If you want to remove the GPL license requirements for inference or use the weights commercially over the revenue limit, please contact me at [email protected] for dual licensing.

111 changes: 0 additions & 111 deletions benchmark/order.py

This file was deleted.

3 changes: 1 addition & 2 deletions ocr_app.py
@@ -70,8 +70,7 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
def order_detection(img) -> (Image.Image, OrderResult):
    _, layout_pred = layout_detection(img)
    bboxes = [l.bbox for l in layout_pred.bboxes]
-    labels = [l.label for l in layout_pred.bboxes]
-    pred = batch_ordering([img], [bboxes], order_model, order_processor, labels=[labels])[0]
+    pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
    polys = [l.polygon for l in pred.bboxes]
    positions = [str(l.position) for l in pred.bboxes]
    order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
5 changes: 1 addition & 4 deletions reading_order.py
@@ -42,14 +42,11 @@ def main():
    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    bboxes = []
-    labels = []
    for layout_pred in layout_predictions:
        bbox = [l.bbox for l in layout_pred.bboxes]
-        label = [l.label for l in layout_pred.bboxes]
        bboxes.append(bbox)
-        labels.append(label)

-    order_predictions = batch_ordering(images, bboxes, model, processor, labels=labels)
+    order_predictions = batch_ordering(images, bboxes, model, processor)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

8 changes: 8 additions & 0 deletions scripts/verify_benchmark_scores.py
@@ -21,6 +21,12 @@ def verify_rec(data):
        raise ValueError("Scores do not meet the required threshold")


+def verify_order(data):
+    score = data["mean_accuracy"]
+    if score <= 0.9:
+        raise ValueError("Scores do not meet the required threshold")


def verify_scores(file_path, bench_type):
    with open(file_path, 'r') as file:
        data = json.load(file)
@@ -31,6 +37,8 @@ def verify_scores(file_path, bench_type):
        verify_rec(data)
    elif bench_type == "layout":
        verify_layout(data)
+    elif bench_type == "ordering":
+        verify_order(data)
    else:
        raise ValueError("Invalid benchmark type")

Binary file added static/images/nyt_order.jpg
Binary file added static/images/textbook_order.jpg
7 changes: 5 additions & 2 deletions surya/model/ordering/model.py
@@ -8,7 +8,7 @@
from surya.settings import settings


-def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
+def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    decoder_config = vars(config.decoder)
@@ -24,8 +24,11 @@ def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder)
    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

-    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config)
+    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
    assert isinstance(model.decoder, MBartOrder)
    assert isinstance(model.encoder, VariableDonutSwinModel)

+    model = model.to(device)
+    model = model.eval()
+    print(f"Loading reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model
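
A short usage sketch for the updated loader (defaults come from `surya.settings`; the explicit arguments are shown for illustration):

```python
import torch
from surya.model.ordering.model import load_model

model = load_model()  # settings.ORDER_MODEL_CHECKPOINT / TORCH_DEVICE_MODEL / MODEL_DTYPE
# or pin device and dtype explicitly, e.g. half precision on a GPU:
model_fp16 = load_model(device="cuda", dtype=torch.float16)
```
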
30 changes: 4 additions & 26 deletions surya/ordering.py
@@ -30,16 +30,11 @@ def rank_elements(arr):
    return rank


-def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, labels: Optional[List[List[str]]] = None) -> List[OrderResult]:
+def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    batch_size = get_batch_size()

-    if labels is not None:
-        assert len(labels) == len(images)
-        for l, b in zip(labels, bboxes):
-            assert len(l) == len(b)

    images = [image.convert("RGB") for image in images]

    output_order = []
@@ -78,11 +73,12 @@

        last_tokens = []
        last_token_mask = []
+        min_val = torch.finfo(model.dtype).min
        for j in range(logits.shape[0]):
            label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
            new_logits = logits[j, -1].clone()
-            new_logits[batch_predictions[j]] = -1e9  # Mask out already predicted tokens, we can only predict each token once
-            new_logits[label_count:] = -1e9  # Mask out all logit positions above the number of bboxes
+            new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens, we can only predict each token once
+            new_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
            pred = int(torch.argmax(new_logits, dim=-1).item())

            # Add one to avoid colliding with the 1000 height/width token for bboxes
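
A note on the masking change above (my reading, not stated in the commit): now that the model can load in half precision, a hard-coded `-1e9` is no longer safe, since it overflows float16, while `torch.finfo(model.dtype).min` always stays within the dtype's range:

```python
import torch

print(torch.finfo(torch.float16).min)           # -65504.0, the float16 minimum
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf) - overflows float16
```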
@@ -119,24 +115,6 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

-            if labels is not None:
-                # This is to force headers/footers into the proper order
-                row_label = labels[i+j]
-                combined = [[i, bbox, label, rank] for i, (bbox, label, rank) in enumerate(zip(row_bboxes, row_label, ranks))]
-                combined = sorted(combined, key=lambda x: x[3])
-
-                sorted_boxes = ([row for row in combined if row[2] == "Page-header"] +
-                                [row for row in combined if row[2] not in ["Page-header", "Page-footer"]] +
-                                [row for row in combined if row[2] == "Page-footer"])
-
-                # Re-rank after sorting
-                for rank, row in enumerate(sorted_boxes):
-                    row[3] = rank
-
-                sorted_boxes = sorted(sorted_boxes, key=lambda x: x[0])
-                row_bboxes = [row[1] for row in sorted_boxes]
-                ranks = [row[3] for row in sorted_boxes]

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
