Skip to content

Commit

Permalink
Add gpu message
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Jan 29, 2025
1 parent 0ccb99c commit 17a5dfe
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
14 changes: 14 additions & 0 deletions olmocr/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def check_sglang_version():
logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.")
sys.exit(1)

def check_torch_gpu_available(min_gpu_memory: int = 8 * 1024**3) -> None:
    """Verify that PyTorch is importable and that GPU 0 has enough memory.

    Args:
        min_gpu_memory: Minimum required GPU memory in bytes (default 8 GiB).

    Raises:
        ImportError: If torch is not installed.
        Exception: If no CUDA device is available or GPU 0 has less than
            ``min_gpu_memory`` bytes of total memory (the underlying torch
            error is re-raised after logging).
    """
    try:
        # Imported lazily so the rest of the checks module works without torch.
        import torch
    except ImportError:
        logger.error("Pytorch must be installed, visit https://pytorch.org/ for installation instructions")
        raise

    try:
        # get_device_properties returns a device-properties object, not an int;
        # the original compared the object itself against min_gpu_memory, so the
        # threshold was never actually enforced — read .total_memory (bytes).
        gpu_memory = torch.cuda.get_device_properties(0).total_memory

        # Raise explicitly rather than via `assert`, which is stripped under -O.
        if gpu_memory < min_gpu_memory:
            raise RuntimeError(f"GPU 0 has {gpu_memory} bytes, need {min_gpu_memory}")
    except Exception:
        # Catches both the RuntimeError above and torch's own errors when no
        # CUDA device exists; narrow enough to let KeyboardInterrupt through.
        logger.error(f"Torch was not able to find a GPU with at least {min_gpu_memory // (1024 ** 3)} GB of RAM.")
        raise


if __name__ == "__main__":
check_poppler_version()
Expand Down
4 changes: 2 additions & 2 deletions olmocr/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from olmocr.filter.filter import PdfFilter, Language
from olmocr.prompts import build_finetuning_prompt, PageResponse
from olmocr.prompts.anchor import get_anchor_text
from olmocr.check import check_poppler_version, check_sglang_version
from olmocr.check import check_poppler_version, check_sglang_version, check_torch_gpu_available
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.version import VERSION

Expand Down Expand Up @@ -470,7 +470,6 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
async def sglang_server_task(args, semaphore):
model_name_or_path = args.model


# if "://" in model_name_or_path:
# # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
# model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
Expand Down Expand Up @@ -902,6 +901,7 @@ async def main():

check_poppler_version()
check_sglang_version()
check_torch_gpu_available()

# Create work queue
if args.workspace.startswith("s3://"):
Expand Down

0 comments on commit 17a5dfe

Please sign in to comment.