Commit

feat: [Experimental] Introduce VLM pipeline using HF AutoModelForVision2Seq, featuring SmolDocling model (#1054)

* Skeleton for SmolDocling model and VLM Pipeline

Signed-off-by: Christoph Auer <[email protected]>
Signed-off-by: Maksym Lysak <[email protected]>

* wip smolDocling inference and vlm pipeline

Signed-off-by: Maksym Lysak <[email protected]>

* WIP, first working code for inference of SmolDocling, and vlm pipeline assembly code, example included.

Signed-off-by: Maksym Lysak <[email protected]>

* Fixes to preserve page image and demo export to html

Signed-off-by: Maksym Lysak <[email protected]>

* Enabled figure support in vlm_pipeline

Signed-off-by: Maksym Lysak <[email protected]>

* Fix for table span compute in vlm_pipeline

Signed-off-by: Maksym Lysak <[email protected]>

* Properly propagating image data per page, together with predicted tags in VLM pipeline. This enables correct figure extraction and page numbers in provenances

Signed-off-by: Maksym Lysak <[email protected]>

* Cleaned up logs, added pages to vlm_pipeline, basic timing per page measurement in smol_docling models

Signed-off-by: Maksym Lysak <[email protected]>

* Replaced hardcoded otsl tokens with the ones from docling-core tokens.py enum

Signed-off-by: Maksym Lysak <[email protected]>

* Added tokens/sec measurement, improved example

Signed-off-by: Maksym Lysak <[email protected]>

* Added capability for vlm_pipeline to grab text from preconfigured backend

Signed-off-by: Maksym Lysak <[email protected]>

* Exposed "force_backend_text" as pipeline parameter

Signed-off-by: Maksym Lysak <[email protected]>

* Flipped keep_backend to True for vlm_pipeline assembly to work

Signed-off-by: Maksym Lysak <[email protected]>

* Updated vlm pipeline assembly and smol docling model code to support updated doctags

Signed-off-by: Maksym Lysak <[email protected]>

* Fixing doctags starting tag, that broke elements on first line during assembly

Signed-off-by: Maksym Lysak <[email protected]>

* Introduced SmolDoclingOptions to configure model parameters (such as query and artifacts path) via client code, see example in minimal_smol_docling. Provisioning for other potential vlm all-in-one models.

Signed-off-by: Maksym Lysak <[email protected]>

* Moved artifacts_path for SmolDocling into vlm_options instead of global pipeline option

Signed-off-by: Maksym Lysak <[email protected]>

* New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging

Signed-off-by: Maksym Lysak <[email protected]>

* Updated example of Smol Docling usage

Signed-off-by: Maksym Lysak <[email protected]>

* Added captions for the images for SmolDocling assembly code, improved provenance definition for all elements

Signed-off-by: Maksym Lysak <[email protected]>

* Update minimal smoldocling example

Signed-off-by: Christoph Auer <[email protected]>

* Fix repo id

Signed-off-by: Christoph Auer <[email protected]>

* Cleaned up unnecessary logging

Signed-off-by: Maksym Lysak <[email protected]>

* More elegant solution in removing the input prompt

Signed-off-by: Maksym Lysak <[email protected]>

* removed minimal_smol_docling example from CI checks

Signed-off-by: Maksym Lysak <[email protected]>

* Removed special html code wrapping when exporting to docling document, cleaned up comments

Signed-off-by: Maksym Lysak <[email protected]>

* Addressing PR comments, added enabled property to SmolDocling, and related VLM pipeline option, few other minor things

Signed-off-by: Maksym Lysak <[email protected]>

* Moved keep_backend = True to vlm pipeline

Signed-off-by: Maksym Lysak <[email protected]>

* removed pipeline_options.generate_table_images from vlm_pipeline (deprecated in the pipelines)

Signed-off-by: Maksym Lysak <[email protected]>

* Added example on how to get original predicted doctags in minimal_smol_docling

Signed-off-by: Maksym Lysak <[email protected]>

* removing changes from base_pipeline

Signed-off-by: Maksym Lysak <[email protected]>

* Replaced remaining strings to appropriate enums

Signed-off-by: Maksym Lysak <[email protected]>

* Updated poetry.lock

Signed-off-by: Maksym Lysak <[email protected]>

* re-built poetry.lock

Signed-off-by: Maksym Lysak <[email protected]>

* Generalize and refactor VLM pipeline and models

Signed-off-by: Christoph Auer <[email protected]>

* Rename example

Signed-off-by: Christoph Auer <[email protected]>

* Move imports

Signed-off-by: Christoph Auer <[email protected]>

* Expose control over using flash_attention_2

Signed-off-by: Christoph Auer <[email protected]>

* Fix VLM example exclusion in CI

Signed-off-by: Christoph Auer <[email protected]>

* Add back device_map and accelerate

Signed-off-by: Christoph Auer <[email protected]>

* Make drawing code resilient against bad bboxes

Signed-off-by: Christoph Auer <[email protected]>

* chore: clean up code and comments

Signed-off-by: Christoph Auer <[email protected]>

* chore: more cleanup

Signed-off-by: Christoph Auer <[email protected]>

* chore: fix leftover .to(device)

Signed-off-by: Christoph Auer <[email protected]>

* fix: add proper table provenance

Signed-off-by: Christoph Auer <[email protected]>

---------

Signed-off-by: Christoph Auer <[email protected]>
Signed-off-by: Maksym Lysak <[email protected]>
Co-authored-by: Maksym Lysak <[email protected]>
cau-git and Maksym Lysak authored Feb 26, 2025
1 parent ab683e4 commit 3c9fe76
Showing 9 changed files with 1,248 additions and 316 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
@@ -28,7 +28,7 @@ jobs:
         run: |
           for file in docs/examples/*.py; do
             # Skip batch_convert.py
-            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
+            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
                 echo "Skipping $file"
                 continue
             fi
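
To illustrate why the extra `minimal_vlm_pipeline` alternative is needed, the sketch below replays the skip check in Python. It is an illustration only: Python's `re.match` stands in for bash's `[[ ... =~ ... ]]`, the alternations are shortened, and the `is_skipped` helper is a hypothetical name that does not exist in the workflow.

```python
import re

# Shortened versions of the alternations in checks.yml (illustrative subset).
OLD_PATTERN = r"^(batch_convert|minimal|export_multimodal).py"
NEW_PATTERN = r"^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal).py"


def is_skipped(filename: str, pattern: str) -> bool:
    """Return True if the CI loop would skip this example file."""
    # re.match anchors at the start of the string, like the leading ^.
    return re.match(pattern, filename) is not None


if __name__ == "__main__":
    for name in ["minimal.py", "minimal_vlm_pipeline.py", "batch_convert.py"]:
        print(name, is_skipped(name, OLD_PATTERN), is_skipped(name, NEW_PATTERN))
```

With the old pattern, `minimal_vlm_pipeline.py` is not skipped: after `minimal` the pattern requires one arbitrary character followed by `py`, but the file name continues with `_vlm...`. Note also that the `.` is unescaped and the pattern has no `$` anchor, so the check is somewhat looser than it looks.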
5 changes: 5 additions & 0 deletions docling/datamodel/base_models.py
@@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []


+class VlmPrediction(BaseModel):
+    text: str = ""
+
+
 class ContainerElement(
     BasePageElement
 ): # Used for Form and Key-Value-Regions, only for typing.
@@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
+    vlm_response: Optional[VlmPrediction] = None


 PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
63 changes: 62 additions & 1 deletion docling/datamodel/pipeline_options.py
@@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):

     num_threads: int = 4
     device: Union[str, AcceleratorDevice] = "auto"
+    cuda_use_flash_attention2: bool = False

     @field_validator("device")
     def validate_device(cls, value):
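
The new `cuda_use_flash_attention2` flag is consumed in `hf_vlm_model.py` further down in this commit, where it is combined with the resolved device string. A minimal sketch of that decision follows; the helper name `pick_attn_implementation` is hypothetical (the commit inlines the expression), and the device strings are only examples.

```python
def pick_attn_implementation(device: str, cuda_use_flash_attention2: bool) -> str:
    """Mirror the conditional expression used when loading the model.

    Flash attention is requested only when the device is a CUDA device
    AND the user opted in; every other combination falls back to "eager".
    """
    if device.startswith("cuda") and cuda_use_flash_attention2:
        return "flash_attention_2"
    return "eager"


if __name__ == "__main__":
    for dev in ["cuda:0", "cpu", "mps"]:
        for flag in (False, True):
            print(dev, flag, pick_attn_implementation(dev, flag))
```

Because the flag defaults to `False`, existing CUDA users keep the eager implementation unless they opt in explicitly.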
@@ -254,6 +255,45 @@ def repo_cache_folder(self) -> str:
         )


+class BaseVlmOptions(BaseModel):
+    kind: str
+    prompt: str
+
+
+class ResponseFormat(str, Enum):
+    DOCTAGS = "doctags"
+    MARKDOWN = "markdown"
+
+
+class HuggingFaceVlmOptions(BaseVlmOptions):
+    kind: Literal["hf_model_options"] = "hf_model_options"
+
+    repo_id: str
+    load_in_8bit: bool = True
+    llm_int8_threshold: float = 6.0
+    quantized: bool = False
+
+    response_format: ResponseFormat
+
+    @property
+    def repo_cache_folder(self) -> str:
+        return self.repo_id.replace("/", "--")
+
+
+smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+)
+
+granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
+    # prompt="OCR the full page to markdown.",
+    prompt="OCR this image.",
+    response_format=ResponseFormat.MARKDOWN,
+)
+
+
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""
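
The `repo_cache_folder` property maps a Hugging Face repo id to a flat folder name, and the model class in `hf_vlm_model.py` uses the same replacement to look for a pre-downloaded model under `artifacts_path`. The sketch below reproduces that lookup with the standard library only. The function `resolve_artifacts_path` is a hypothetical stand-in, and it returns `None` where the commit's code would call `download_models`.

```python
import tempfile
from pathlib import Path
from typing import Optional


def repo_cache_folder(repo_id: str) -> str:
    """Flatten 'org/name' into 'org--name', as the property in the diff does."""
    return repo_id.replace("/", "--")


def resolve_artifacts_path(artifacts_path: Optional[Path], repo_id: str) -> Optional[Path]:
    """Pick the directory to load the model from.

    None means "nothing local was configured"; the real code downloads
    the model at this point (assumption: this sketch does not).
    """
    if artifacts_path is None:
        return None
    candidate = artifacts_path / repo_cache_folder(repo_id)
    # Prefer the per-repo subfolder when it exists, else use the path as given.
    return candidate if candidate.exists() else artifacts_path


if __name__ == "__main__":
    repo = "ds4sd/SmolDocling-256M-preview"
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        print(resolve_artifacts_path(root, repo))  # root itself, no subfolder yet
        (root / repo_cache_folder(repo)).mkdir()
        print(resolve_artifacts_path(root, repo))  # now the per-repo subfolder
```

One consequence of this lookup is that `artifacts_path` can point either at a shared models directory or directly at the model folder.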
@@ -285,7 +325,24 @@ class PipelineOptions(BaseModel):
     enable_remote_services: bool = False


-class PdfPipelineOptions(PipelineOptions):
+class PaginatedPipelineOptions(PipelineOptions):
+    images_scale: float = 1.0
+    generate_page_images: bool = False
+    generate_picture_images: bool = False
+
+
+class VlmPipelineOptions(PaginatedPipelineOptions):
+    artifacts_path: Optional[Union[Path, str]] = None
+
+    generate_page_images: bool = True
+    force_backend_text: bool = (
+        False # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
+    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
+
+
+class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""

     artifacts_path: Optional[Union[Path, str]] = None
@@ -295,6 +352,10 @@ class PdfPipelineOptions(PipelineOptions):
     do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
     do_picture_classification: bool = False # True: classify pictures in documents
     do_picture_description: bool = False # True: run describe pictures in documents
+    force_backend_text: bool = (
+        False # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text

     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
180 changes: 180 additions & 0 deletions docling/models/hf_vlm_model.py
@@ -0,0 +1,180 @@
+import logging
+import time
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
+    AcceleratorOptions,
+    HuggingFaceVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_model import BasePageModel
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class HuggingFaceVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: HuggingFaceVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            import torch
+            from transformers import ( # type: ignore
+                AutoModelForVision2Seq,
+                AutoProcessor,
+                BitsAndBytesConfig,
+            )
+
+            device = decide_device(accelerator_options.device)
+            self.device = device
+
+            _log.debug("Available device for HuggingFace VLM: {}".format(device))
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.prompt # "Perform Layout Analysis."
+            self.param_quantization_config = BitsAndBytesConfig(
+                load_in_8bit=vlm_options.load_in_8bit, # True,
+                llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
+            )
+            self.param_quantized = vlm_options.quantized # False
+
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            if not self.param_quantized:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype=torch.bfloat16,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                ) # .to(self.device)
+
+            else:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype="auto",
+                    quantization_config=self.param_quantization_config,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                ) # .to(self.device)
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=2.0) # 144dpi
+                    # hi_res_image = page.get_image(scale=1.0) # 72dpi
+
+                    if hi_res_image is not None:
+                        im_width, im_height = hi_res_image.size
+
+                    # populate page_tags with predicted doc tags
+                    page_tags = ""
+
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "This is a page from a document.",
+                                },
+                                {"type": "image"},
+                                {"type": "text", "text": self.param_question},
+                            ],
+                        }
+                    ]
+                    prompt = self.processor.apply_chat_template(
+                        messages, add_generation_prompt=False
+                    )
+                    inputs = self.processor(
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
+                    )
+                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                    start_time = time.time()
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
+                        **inputs, max_new_tokens=4096, use_cache=True
+                    )
+
+                    generation_time = time.time() - start_time
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
+                    )[0]
+
+                    num_tokens = len(generated_ids[0])
+                    page_tags = generated_texts
+
+                    # inference_time = time.time() - start_time
+                    # tokens_per_second = num_tokens / generation_time
+                    # print("")
+                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
+                    # print(f"Total tokens on page: {num_tokens:.2f}")
+                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
+                    # print("")
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                yield page
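
The `__call__` method removes the input prompt from the model output by slicing `generated_ids` at the prompt length before decoding. The sketch below shows the same idea on plain Python lists, together with a throughput calculation like the one in the commented-out timing code. Both helper names are hypothetical, and counting only the newly generated tokens is a choice made for this sketch: the commit's `num_tokens = len(generated_ids[0])` appears to count the prompt tokens as well.

```python
from typing import List


def strip_prompt(generated_ids: List[List[int]], prompt_len: int) -> List[List[int]]:
    """List analogue of generated_ids[:, prompt_len:] on a 2-D tensor.

    Each output sequence starts with the prompt tokens, so dropping the
    first prompt_len ids leaves only the newly generated tokens.
    """
    return [seq[prompt_len:] for seq in generated_ids]


def tokens_per_second(num_new_tokens: int, seconds: float) -> float:
    """Throughput over the generation call; rejects non-positive durations."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return num_new_tokens / seconds


if __name__ == "__main__":
    prompt_ids = [101, 7, 8]                 # made-up token ids
    output_ids = [prompt_ids + [42, 43, 44, 45]]
    new_ids = strip_prompt(output_ids, len(prompt_ids))
    print(new_ids)
    print(tokens_per_second(len(new_ids[0]), 0.5))
```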