Generalize and refactor VLM pipeline and models
Signed-off-by: Christoph Auer <[email protected]>
cau-git committed Feb 25, 2025
1 parent 1c75b52 commit 1cba96e
Showing 5 changed files with 165 additions and 96 deletions.
6 changes: 3 additions & 3 deletions docling/datamodel/base_models.py
@@ -154,8 +154,8 @@ class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []


-class DocTagsPrediction(BaseModel):
-    tag_string: str = ""
+class VlmPrediction(BaseModel):
+    text: str = ""


 class ContainerElement(
@@ -201,7 +201,7 @@ class PagePredictions(BaseModel):
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
-    doctags: Optional[DocTagsPrediction] = None
+    vlm_response: Optional[VlmPrediction] = None


 PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
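
In practice this rename means downstream code reads the raw VLM output from page.predictions.vlm_response.text rather than page.predictions.doctags.tag_string, independent of whether the model emits DocTags or Markdown. A minimal caller-side sketch (the helper name is illustrative, not part of this commit):

from docling.datamodel.base_models import Page


def get_vlm_text(page: Page) -> str:
    # Formerly: page.predictions.doctags.tag_string
    # Now a format-agnostic VlmPrediction carries the generated text.
    if page.predictions.vlm_response is not None:
        return page.predictions.vlm_response.text
    return ""
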
42 changes: 35 additions & 7 deletions docling/datamodel/pipeline_options.py
@@ -254,12 +254,44 @@ def repo_cache_folder(self) -> str:
 )


-class SmolDoclingOptions(BaseModel):
-    question: str = "Convert this page to docling."
+class BaseVlmOptions(BaseModel):
+    kind: str
+    prompt: str


+class ResponseFormat(str, Enum):
+    DOCTAGS = "doctags"
+    MARKDOWN = "markdown"


+class HuggingFaceVlmOptions(BaseVlmOptions):
+    kind: Literal["hf_model_options"] = "hf_model_options"

+    repo_id: str
+    load_in_8bit: bool = True
+    llm_int8_threshold: float = 6.0
+    quantized: bool = False

+    response_format: ResponseFormat

+    @property
+    def repo_cache_folder(self) -> str:
+        return self.repo_id.replace("/", "--")


+smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+)

+granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
+    # prompt="OCR the full page to markdown.",
+    prompt="OCR this image.",
+    response_format=ResponseFormat.MARKDOWN,
+)


 # Define an enum for the backend options
 class PdfBackend(str, Enum):
@@ -300,13 +332,11 @@ class PaginatedPipelineOptions(PipelineOptions):

 class VlmPipelineOptions(PaginatedPipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
-    do_vlm: bool = True  # True: perform inference of Visual Language Model

     force_backend_text: bool = (
         False  # (To be used with vlms, or other generative models)
     )
     # If True, text from backend will be used instead of generated text
-    vlm_options: Union[SmolDoclingOptions,] = Field(SmolDoclingOptions())
+    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options


 class PdfPipelineOptions(PaginatedPipelineOptions):
@@ -337,8 +367,6 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
         Field(discriminator="kind"),
     ] = smolvlm_picture_description

-    vlm_options: Union[SmolDoclingOptions,] = Field(SmolDoclingOptions())

     images_scale: float = 1.0
     generate_page_images: bool = False
     generate_picture_images: bool = False
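
With these options in place, a pipeline is pointed at either preset (or a custom HuggingFaceVlmOptions) through VlmPipelineOptions.vlm_options. A rough usage sketch, assuming the VlmPipeline class and converter wiring from the rest of docling, which are not shown in this diff; the input path is illustrative:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Swap the default SmolDocling preset for the Granite Vision preset,
# which responds with Markdown instead of DocTags.
pipeline_options = VlmPipelineOptions(
    vlm_options=granite_vision_vlm_conversion_options,
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("report.pdf")
print(result.document.export_to_markdown())
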
@@ -3,12 +3,14 @@
 from pathlib import Path
 from typing import Iterable, List, Optional

-from docling.datamodel.base_models import DocTagsPrediction, Page
+from transformers import AutoModelForVision2Seq

+from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
     AcceleratorDevice,
     AcceleratorOptions,
-    SmolDoclingOptions,
+    HuggingFaceVlmOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.base_model import BasePageModel
@@ -18,19 +20,19 @@
 _log = logging.getLogger(__name__)


-class SmolDoclingModel(BasePageModel):

-    _repo_id: str = "ds4sd/SmolDocling-256M-preview"
+class HuggingFaceVlmModel(BasePageModel):

     def __init__(
         self,
         enabled: bool,
         artifacts_path: Optional[Path],
         accelerator_options: AcceleratorOptions,
-        vlm_options: SmolDoclingOptions,
+        vlm_options: HuggingFaceVlmOptions,
     ):
         self.enabled = enabled

+        self.vlm_options = vlm_options

         if self.enabled:
             import torch
             from transformers import (  # type: ignore
@@ -42,17 +44,17 @@ def __init__(
             device = decide_device(accelerator_options.device)
             self.device = device

-            _log.debug("Available device for SmolDocling: {}".format(device))
+            _log.debug("Available device for HuggingFace VLM: {}".format(device))

-            repo_cache_folder = self._repo_id.replace("/", "--")
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

             # PARAMETERS:
             if artifacts_path is None:
-                artifacts_path = self.download_models()
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

-            self.param_question = vlm_options.question  # "Perform Layout Analysis."
+            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
             self.param_quantization_config = BitsAndBytesConfig(
                 load_in_8bit=vlm_options.load_in_8bit,  # True,
                 llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
@@ -61,22 +63,27 @@ def __init__(

             self.processor = AutoProcessor.from_pretrained(artifacts_path)
             if not self.param_quantized:
-                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                     artifacts_path,
                     # device_map=device,
                     torch_dtype=torch.bfloat16,
-                )
-                self.vlm_model = self.vlm_model.to(device)
+                    # _attn_implementation=(
+                    #     "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    # ),
+                ).to(self.device)

             else:
-                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                     artifacts_path,
                     # device_map=device,
                     torch_dtype="auto",
                     quantization_config=self.param_quantization_config,
-                ).to(device)
+                    # _attn_implementation=(
+                    #     "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    # ),
+                ).to(self.device)

     @staticmethod
     def download_models(
+        repo_id: str,
         local_dir: Optional[Path] = None,
         force: bool = False,
         progress: bool = False,
@@ -87,7 +94,7 @@ def download_models(
         if not progress:
             disable_progress_bars()
         download_path = snapshot_download(
-            repo_id=SmolDoclingModel._repo_id,
+            repo_id=repo_id,
             force_download=force,
             local_dir=local_dir,
             # revision="v0.0.1",
@@ -155,13 +162,13 @@ def __call__(
                 num_tokens = len(generated_ids[0])
                 page_tags = generated_texts

-                inference_time = time.time() - start_time
-                tokens_per_second = num_tokens / generation_time
+                # inference_time = time.time() - start_time
+                # tokens_per_second = num_tokens / generation_time
                 # print("")
                 # print(f"Page Inference Time: {inference_time:.2f} seconds")
                 # print(f"Total tokens on page: {num_tokens:.2f}")
                 # print(f"Tokens/sec: {tokens_per_second:.2f}")
                 # print("")
-                page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)
+                page.predictions.vlm_response = VlmPrediction(text=page_tags)

                 yield page
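
Because download_models now takes the repository id as a parameter, weights for any supported VLM can be pre-fetched the same way. A sketch, assuming the refactored class lives at docling.models.hf_vlm_model (the module path is not visible in this diff):

from pathlib import Path

from docling.models.hf_vlm_model import HuggingFaceVlmModel

# Pre-download Granite Vision weights into a local artifacts folder; the
# folder name mirrors the repo_cache_folder convention ("/" replaced by "--").
artifacts_path = HuggingFaceVlmModel.download_models(
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
    local_dir=Path("model_artifacts") / "ibm-granite--granite-vision-3.1-2b-preview",
    progress=True,
)
print(artifacts_path)
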