Skip to content

Commit

Permalink
Addressing PR comments, added enabled property to SmolDocling, and re…
Browse files Browse the repository at this point in the history
…lated VLM pipeline option, few other minor things

Signed-off-by: Maksym Lysak <[email protected]>
  • Loading branch information
Maksym Lysak committed Feb 13, 2025
1 parent be11210 commit 9476e6a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 48 deletions.
3 changes: 2 additions & 1 deletion docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def repo_cache_folder(self) -> str:


class SmolDoclingOptions(BaseModel):
question: str = "Convert this page to docling." # "Perform Layout Analysis."
question: str = "Convert this page to docling."
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
Expand Down Expand Up @@ -275,6 +275,7 @@ class PaginatedPipelineOptions(PipelineOptions):

class VlmPipelineOptions(PaginatedPipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
do_vlm: bool = True # True: perform inference of Visual Language Model

force_backend_text: bool = (
False # (To be used with vlms, or other generative models)
Expand Down
80 changes: 41 additions & 39 deletions docling/models/smol_docling_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@
from pathlib import Path
from typing import Iterable, List, Optional

import torch
from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
from transformers import ( # type: ignore
AutoProcessor,
BitsAndBytesConfig,
Idefics3ForConditionalGeneration,
)

from docling.datamodel.base_models import DocTagsPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
Expand All @@ -32,46 +24,56 @@ class SmolDoclingModel(BasePageModel):

def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: SmolDoclingOptions,
):
device = decide_device(accelerator_options.device)
self.device = device
self.enabled = enabled

if self.enabled:
import torch
from transformers import ( # type: ignore
AutoProcessor,
BitsAndBytesConfig,
Idefics3ForConditionalGeneration,
)

_log.debug("Available device for SmolDocling: {}".format(device))
device = decide_device(accelerator_options.device)
self.device = device

repo_cache_folder = self._repo_id.replace("/", "--")
_log.debug("Available device for SmolDocling: {}".format(device))

# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models()
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
repo_cache_folder = self._repo_id.replace("/", "--")

self.param_question = vlm_options.question # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
)
self.param_quantized = vlm_options.quantized # False

self.processor = AutoProcessor.from_pretrained(artifacts_path)
if not self.param_quantized:
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
artifacts_path,
# device_map=device,
torch_dtype=torch.bfloat16,
# _attn_implementation="flash_attention_2",
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models()
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder

self.param_question = vlm_options.question # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
)
self.vlm_model = self.vlm_model.to(device)
else:
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
artifacts_path,
# device_map=device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
).to(device)
self.param_quantized = vlm_options.quantized # False

self.processor = AutoProcessor.from_pretrained(artifacts_path)
if not self.param_quantized:
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
artifacts_path,
# device_map=device,
torch_dtype=torch.bfloat16,
)
self.vlm_model = self.vlm_model.to(device)
else:
self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
artifacts_path,
# device_map=device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
).to(device)

@staticmethod
def download_models(
Expand Down
5 changes: 4 additions & 1 deletion docling/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.

def __init__(self, pipeline_options: PipelineOptions):
super().__init__(pipeline_options)
self.keep_backend = True
self.keep_backend = (
True # For now, need to be able to query for page size post prediction
)
# self.keep_backend = False

def _apply_on_pages(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
Expand Down
1 change: 0 additions & 1 deletion docling/pipeline/standard_pdf_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class StandardPdfPipeline(PaginatedPipeline):

def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options)
print("------> Init Standard PDF Pipeline!")
self.pipeline_options: PdfPipelineOptions

artifacts_path: Optional[Path] = None
Expand Down
18 changes: 13 additions & 5 deletions docling/pipeline/vlm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(self, pipeline_options: VlmPipelineOptions):

self.build_pipe = [
SmolDoclingModel(
enabled=pipeline_options.do_vlm,
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
Expand Down Expand Up @@ -315,13 +316,15 @@ def otsl_extract_tokens_and_text(s: str):
token
for token in tokens
if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
# if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
]
# Split the string by those tokens to get the in-between text
text_parts = re.split(pattern, s)
text_parts = [
token
for token in text_parts
if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
# if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
]
# Remove any empty or purely whitespace strings from text_parts
text_parts = [part for part in text_parts if part.strip()]
Expand Down Expand Up @@ -365,10 +368,15 @@ def parse_table_content(otsl_content: str) -> TableData:

# Regex for all recognized tags
tag_pattern = (
r"<(?P<tag>title|document_index|otsl|section_header_level_1|checkbox_selected|"
r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|"
r"list_item|footnote|code)>.*?</(?P=tag)>"
rf"<(?P<tag>{DocItemLabel.TITLE}|{DocItemLabel.DOCUMENT_INDEX}|"
rf"{DocItemLabel.CHECKBOX_UNSELECTED}|{DocItemLabel.CHECKBOX_SELECTED}|"
rf"{DocItemLabel.TEXT}|{DocItemLabel.PAGE_HEADER}|"
rf"{DocItemLabel.PAGE_FOOTER}|{DocItemLabel.FORMULA}|"
rf"{DocItemLabel.CAPTION}|{DocItemLabel.PICTURE}|"
rf"{DocItemLabel.LIST_ITEM}|{DocItemLabel.FOOTNOTE}|{DocItemLabel.CODE}|"
rf"{DocItemLabel.SECTION_HEADER}_level_1|otsl)>.*?</(?P=tag)>"
)

pattern = re.compile(tag_pattern, re.DOTALL)

# Go through each match in order
Expand Down Expand Up @@ -456,8 +464,8 @@ def parse_table_content(otsl_content: str) -> TableData:
return doc

@classmethod
def get_default_options(cls) -> PdfPipelineOptions:
return PdfPipelineOptions()
def get_default_options(cls) -> VlmPipelineOptions:
return VlmPipelineOptions()

@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/minimal_smol_docling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
pipeline_options.generate_page_images = True
# If force_backend_text = True, text from backend will be used instead of generated text
pipeline_options.force_backend_text = False

# pipeline_options.do_vlm = True - use False to disable VLM model (i.e. SmallDocling), extra python imports will not be performed

vlm_options = SmolDoclingOptions(
# question="Convert this page to docling.",
Expand Down

0 comments on commit 9476e6a

Please sign in to comment.