Skip to content

Commit

Permalink
Merge pull request #28 from allenai/amanr/code_documentation
Browse files Browse the repository at this point in the history
Resolved Git checks and updated readme
  • Loading branch information
jakep-allenai authored Feb 10, 2025
2 parents 9bf3d35 + f57c6f3 commit e627842
Show file tree
Hide file tree
Showing 18 changed files with 73 additions and 64 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

- Fixed git checks
39 changes: 17 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# olmOCR

Toolkit for training language models to work with PDF documents in the wild.
A toolkit for training language models to work with PDF documents in the wild.


<img src="https://github.com/user-attachments/assets/d70c8644-3e64-4230-98c3-c52fddaeccb6" alt="olmOCR Logo" width="300"/>
<br/>

Online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)
Try the online demo: [https://olmocr.allen.ai/](https://olmocr.allen.ai/)

What is included:
- A prompting strategy to get really good natural text parsing using ChatGPT 4o - [buildsilver.py](https://github.com/allenai/olmocr/blob/main/olmocr/data/buildsilver.py)
Expand All @@ -22,15 +22,15 @@ Requirements:
- Recent NVIDIA GPU (tested on RTX 4090, L40S, A100, H100)
- 30GB of free disk space

You will need to install poppler-utils and some additional fonts as a prerequisite. olmOCR uses poppler to render its PDF images.
You will need to install poppler-utils and additional fonts for rendering PDF images.

Linux Ubuntu/Debian
Install dependencies (Ubuntu/Debian)
```bash
sudo apt-get update
sudo apt-get install poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
```

Set up a conda environment, then clone and install the olmocr package
Set up a conda environment and install olmocr
```bash
conda create -n olmocr python=3.11
conda activate olmocr
Expand All @@ -40,45 +40,40 @@ cd olmocr
pip install -e .
```

Finally, make sure you have sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) installed if you want to run inference on your own GPU.
Install sglang with [flashinfer](https://github.com/flashinfer-ai/flashinfer) if you want to run inference on GPU.
```bash
pip install sgl-kernel==0.0.3.post1 --force-reinstall --no-deps
pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```

**BETA TESTER NOTE:**

If you are a beta tester, you will need to login using the hugging-face CLI
to make sure you have access to https://huggingface.co/allenai/olmocr-preview

`huggingface-cli login`

If you’re a beta tester, log in with Hugging Face CLI to access (olmOCR)[https://huggingface.co/allenai/olmocr-preview] preview model:
``` bash
huggingface-cli login
```
### Local Usage Example

The easiest way to try out olmOCR on one or two PDFs is to check out the [web demo](https://olmocr.allen.ai/).

Once you are ready to run locally, a local GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang)
under the hood.

This command will convert one PDF into a directory called `localworkspace`:
For quick testing, try the [web demo](https://olmocr.allen.ai/). To run locally, a GPU is required, as inference is powered by [sglang](https://github.com/sgl-project/sglang) under the hood.
Convert a Single PDF:
```bash
python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf
python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/horribleocr.pdf # will convert one PDF into a directory called `localworkspace`
```

You can also bulk convert many PDFS with a glob pattern:
Convert Multiple PDFs:
```bash
python -m olmocr.pipeline ./localworkspace --pdfs tests/gnarly_pdfs/*.pdf
```

#### Viewing Results

Once that finishes, output is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.
Extracted text is stored as [Dolma](https://github.com/allenai/dolma)-style JSONL inside of the `./localworkspace/results` directory.

```bash
cat localworkspace/results/output_*.jsonl
```

You can view your documents side-by-side with the original PDF renders using the `dolmaviewer` command.
View results side-by-side with the original PDFs (uses `dolmaviewer` command):

```bash
python -m olmocr.viewer.dolmaviewer localworkspace/results/output_*.jsonl
Expand Down Expand Up @@ -106,7 +101,7 @@ Now on any subsequent nodes, just run this and they will start grabbing items fr
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace
```

If you are at AI2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
If you are at Ai2 and want to linearize millions of PDFs efficiently using [beaker](https://www.beaker.org), just add the `--beaker`
flag. This will prepare the workspace on your local machine, and then launch N GPU workers in the cluster to start
converting PDFs.

Expand Down
9 changes: 5 additions & 4 deletions olmocr/data/buildsilverdatasummary.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import argparse
import collections
import csv
import json
import os
import random
import re
import sqlite3
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional
from urllib.parse import urlparse

from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> str:
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
match = re.match(pattern, pretty_pdf_path)
if match:
Expand Down Expand Up @@ -58,7 +59,7 @@ def cache_athena_csv_to_db(athena_csv_path: str) -> str:
return db_path


def get_uri_from_db(db_path: str, pdf_hash: str) -> str:
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
Expand Down Expand Up @@ -154,7 +155,7 @@ def main():
for cid, uri, domain in all_rows:
writer.writerow([cid, uri if uri else "", domain if domain else ""])

domain_counter = collections.Counter()
domain_counter: Counter[str] = Counter()
for _, _, domain in all_rows:
if domain:
domain_counter[domain] += 1
Expand Down
6 changes: 3 additions & 3 deletions olmocr/data/renderpdf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import io
import subprocess
from typing import List

from PIL import Image

Expand All @@ -25,12 +26,11 @@ def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[

# Parse the output to find MediaBox
output = result.stdout
media_box = None

for line in output.splitlines():
if "MediaBox" in line:
media_box = line.split(":")[1].strip().split()
media_box = [float(x) for x in media_box]
media_box_str: List[str] = line.split(":")[1].strip().split()
media_box: List[float] = [float(x) for x in media_box_str]
return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])

raise ValueError("MediaBox not found in the PDF info.")
Expand Down
4 changes: 2 additions & 2 deletions olmocr/data/runopenaibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def get_estimated_space_usage(folder_path):


def get_next_work_item(folder_path):
all_states = get_state(folder_path)
all_states = [s for s in all_states.values() if s["state"] not in FINISHED_STATES]
all_states = list(get_state(folder_path).values())
all_states = [s for s in all_states if s["state"] not in FINISHED_STATES]
all_states.sort(key=lambda s: s["last_checked"])

return all_states[0] if len(all_states) > 0 else None
Expand Down
10 changes: 8 additions & 2 deletions olmocr/eval/buildelo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ class Comparison:

@property
def comparison_a_method(self):
return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path).group(1)
match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_a_path}")

@property
def comparison_b_method(self):
return re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path).group(1)
match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_b_path}")


def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"):
Expand Down
4 changes: 2 additions & 2 deletions olmocr/eval/runeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def list_jsonl_files(path: str) -> list:
# Returns the average Levenshtein distance match between the data
def process_jsonl_file(jsonl_file, gold_data, comparer):
page_data = {}
total_alignment_score = 0
char_weighted_alignment_score = 0
total_alignment_score: float = 0.0
char_weighted_alignment_score: float = 0.0
total_pages = 0
total_chars = 0
total_errors = 0
Expand Down
5 changes: 3 additions & 2 deletions olmocr/eval/scoreelo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import csv
import re
from collections import defaultdict
from typing import Any, DefaultDict
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit

import requests
import requests # type: ignore


def fetch_review_page_html(url):
Expand Down Expand Up @@ -108,7 +109,7 @@ def build_comparison_report(entries_dict, datastore):
comparisons[(A, B)] = [A_wins, B_wins],
where A < B lexicographically in that tuple.
"""
comparisons = defaultdict(lambda: [0, 0])
comparisons: DefaultDict[Any, list[int]] = defaultdict(lambda: [0, 0])

for entry_id, vote in datastore.items():
if entry_id not in entries_dict:
Expand Down
5 changes: 3 additions & 2 deletions olmocr/filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import subprocess
from collections import Counter
from typing import List

from lingua import Language, LanguageDetectorBuilder
from pypdf import PdfReader
Expand Down Expand Up @@ -142,7 +143,7 @@ def process_pdf(s3_path):

# Load the list of S3 paths with a progress bar
with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f:
s3_work_paths = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
s3_work_paths: List[str] = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))

# Initialize the PDF filter
filter = PdfFilter(
Expand Down Expand Up @@ -173,7 +174,7 @@ def process_pdf(s3_path):

while pending_futures:
# Wait for the next future to complete
done, _ = wait(
done, _ = wait( # type: ignore
pending_futures.keys(),
timeout=0.1,
return_when=FIRST_COMPLETED,
Expand Down
12 changes: 6 additions & 6 deletions olmocr/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
from collections import defaultdict, deque
from typing import Dict
from typing import Any, Deque, Dict, List, Set


class MetricsKeeper:
Expand All @@ -15,7 +15,7 @@ def __init__(self, window=60 * 5):
self.window = window # Time window in seconds
self.start_time = time.time() # Timestamp when MetricsKeeper was created
self.total_metrics = defaultdict(int) # Cumulative metrics since start
self.window_metrics = deque() # Deque to store (timestamp, metrics_dict)
self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
self.window_sum = defaultdict(int) # Sum of metrics within the window

def add_metrics(self, **kwargs):
Expand Down Expand Up @@ -108,16 +108,16 @@ async def get_status_table(self) -> str:
"""
async with self.lock:
# Determine all unique states across all workers
all_states = set()
all_states: Set[str] = set()
for states in self.worker_status.values():
all_states.update(states.keys())
all_states = sorted(all_states)
sorted_states: List[str] = sorted(all_states)

headers = ["Worker ID"] + all_states
headers = ["Worker ID"] + sorted_states # type: ignore
rows = []
for worker_id, states in sorted(self.worker_status.items()):
row = [str(worker_id)]
for state in all_states:
for state in sorted_states:
count = states.get(state, 0)
row.append(str(count))
rows.append(row)
Expand Down
4 changes: 2 additions & 2 deletions olmocr/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page
)

image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)
image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text) # type: ignore
if image_rotation != 0:
image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img:
Expand Down Expand Up @@ -659,7 +659,7 @@ async def metrics_reporter(work_queue):


def submit_beaker_job(args):
from beaker import (
from beaker import ( # type: ignore
Beaker,
Constraints,
EnvVar,
Expand Down
18 changes: 9 additions & 9 deletions olmocr/prompts/anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_anchor_text(

scores = {label: get_document_coherency(text) for label, text in options.items()}

best_option_label = max(scores, key=scores.get)
best_option_label = max(scores, key=scores.get) # type: ignore
best_option = options[best_option_label]

print(f"topcoherency chosen: {best_option_label}")
Expand Down Expand Up @@ -194,7 +194,7 @@ def bboxes_overlap(b1: BoundingBox, b2: BoundingBox, tolerance: float) -> bool:
union(i, j)

# Group images by their root parent
groups = {}
groups: dict[int, list[int]] = {}
for i in range(n):
root = find(i)
groups.setdefault(root, []).append(i)
Expand Down Expand Up @@ -268,21 +268,21 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:

# Process text elements
text_strings = []
for element in report.text_elements:
if len(element.text.strip()) == 0:
for element in report.text_elements: # type: ignore
if len(element.text.strip()) == 0: # type: ignore
continue

element_text = _cleanup_element_text(element.text)
text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n"
element_text = _cleanup_element_text(element.text) # type: ignore
text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n" # type: ignore
text_strings.append((element, text_str))

# Combine all elements with their positions for sorting
all_elements = []
all_elements: list[tuple[str, ImageElement, str, tuple[float, float]]] = []
for elem, s in image_strings:
position = (elem.bbox.x0, elem.bbox.y0)
all_elements.append(("image", elem, s, position))
for elem, s in text_strings:
position = (elem.x, elem.y)
position = (elem.x, elem.y) # type: ignore
all_elements.append(("text", elem, s, position))

# Calculate total length
Expand Down Expand Up @@ -311,7 +311,7 @@ def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
max_x_text = max(text_elements, key=lambda e: e.x)
min_y_text = min(text_elements, key=lambda e: e.y)
max_y_text = max(text_elements, key=lambda e: e.y)
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text]) # type: ignore

# Keep track of element IDs to prevent duplication
selected_element_ids = set()
Expand Down
4 changes: 2 additions & 2 deletions olmocr/s3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from urllib.parse import urlparse

import boto3
import requests
import requests # type: ignore
import zstandard as zstd
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
Expand Down Expand Up @@ -58,7 +58,7 @@ def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
for obj in page.get("Contents", []):
key = obj["Key"]
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)): # type: ignore
matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
return matched

Expand Down
Loading

0 comments on commit e627842

Please sign in to comment.