From 9476e6a56652662cf3a249dbd6029ec4f5bca24a Mon Sep 17 00:00:00 2001 From: Maksym Lysak Date: Thu, 13 Feb 2025 17:19:53 +0100 Subject: [PATCH] Addressing PR comments, added enabled property to SmolDocling, and related VLM pipeline option, few other minor things Signed-off-by: Maksym Lysak --- docling/datamodel/pipeline_options.py | 3 +- docling/models/smol_docling_model.py | 80 ++++++++++++----------- docling/pipeline/base_pipeline.py | 5 +- docling/pipeline/standard_pdf_pipeline.py | 1 - docling/pipeline/vlm_pipeline.py | 18 +++-- docs/examples/minimal_smol_docling.py | 2 +- 6 files changed, 61 insertions(+), 48 deletions(-) diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 3c4ec1fdd..86d808c24 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -230,7 +230,7 @@ def repo_cache_folder(self) -> str: class SmolDoclingOptions(BaseModel): - question: str = "Convert this page to docling." # "Perform Layout Analysis." + question: str = "Convert this page to docling." load_in_8bit: bool = True llm_int8_threshold: float = 6.0 quantized: bool = False @@ -275,6 +275,7 @@ class PaginatedPipelineOptions(PipelineOptions): class VlmPipelineOptions(PaginatedPipelineOptions): artifacts_path: Optional[Union[Path, str]] = None + do_vlm: bool = True # True: perform inference of Visual Language Model force_backend_text: bool = ( False # (To be used with vlms, or other generative models) diff --git a/docling/models/smol_docling_model.py b/docling/models/smol_docling_model.py index a66029c36..00c04fa7d 100644 --- a/docling/models/smol_docling_model.py +++ b/docling/models/smol_docling_model.py @@ -3,14 +3,6 @@ from pathlib import Path from typing import Iterable, List, Optional -import torch -from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS -from transformers import ( # type: ignore - AutoProcessor, - BitsAndBytesConfig, - Idefics3ForConditionalGeneration, -) - from docling.datamodel.base_models import DocTagsPrediction, Page from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ( @@ -32,46 +24,56 @@ class SmolDoclingModel(BasePageModel): def __init__( self, + enabled: bool, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions, vlm_options: SmolDoclingOptions, ): - device = decide_device(accelerator_options.device) - self.device = device + self.enabled = enabled + + if self.enabled: + import torch + from transformers import ( # type: ignore + AutoProcessor, + BitsAndBytesConfig, + Idefics3ForConditionalGeneration, + ) - _log.debug("Available device for SmolDocling: {}".format(device)) + device = decide_device(accelerator_options.device) + self.device = device - repo_cache_folder = self._repo_id.replace("/", "--") + _log.debug("Available device for SmolDocling: {}".format(device)) - # PARAMETERS: - if artifacts_path is None: - artifacts_path = self.download_models() - elif (artifacts_path / repo_cache_folder).exists(): - artifacts_path = artifacts_path / repo_cache_folder + repo_cache_folder = self._repo_id.replace("/", "--") - self.param_question = vlm_options.question # "Perform Layout Analysis." 
-        self.param_quantization_config = BitsAndBytesConfig(
-            load_in_8bit=vlm_options.load_in_8bit,  # True,
-            llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
-        )
-        self.param_quantized = vlm_options.quantized  # False
-
-        self.processor = AutoProcessor.from_pretrained(artifacts_path)
-        if not self.param_quantized:
-            self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
-                artifacts_path,
-                # device_map=device,
-                torch_dtype=torch.bfloat16,
-                # _attn_implementation="flash_attention_2",
-            )
-            self.vlm_model = self.vlm_model.to(device)
-        else:
-            self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
-                artifacts_path,
-                # device_map=device,
-                torch_dtype="auto",
-                quantization_config=self.param_quantization_config,
-            ).to(device)
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models()
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.question  # "Perform Layout Analysis."
+            self.param_quantization_config = BitsAndBytesConfig(
+                load_in_8bit=vlm_options.load_in_8bit,  # True,
+                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
+            )
+            self.param_quantized = vlm_options.quantized  # False
+
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            if not self.param_quantized:
+                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                    artifacts_path,
+                    # device_map=device,
+                    torch_dtype=torch.bfloat16,
+                )
+                self.vlm_model = self.vlm_model.to(device)
+            else:
+                self.vlm_model = Idefics3ForConditionalGeneration.from_pretrained(
+                    artifacts_path,
+                    # device_map=device,
+                    torch_dtype="auto",
+                    quantization_config=self.param_quantization_config,
+                ).to(device)
 
     @staticmethod
     def download_models(
diff --git a/docling/pipeline/base_pipeline.py b/docling/pipeline/base_pipeline.py
index 01ed71a09..d08cf85a2 100644
--- a/docling/pipeline/base_pipeline.py
+++ b/docling/pipeline/base_pipeline.py
@@ -116,7 +116,10 @@ class PaginatedPipeline(BasePipeline):  # TODO this is a bad name.
     def __init__(self, pipeline_options: PipelineOptions):
         super().__init__(pipeline_options)
-        self.keep_backend = True
+        self.keep_backend = (
+            True  # For now, need to be able to query for page size post prediction
+        )
+        # self.keep_backend = False
 
     def _apply_on_pages(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index ce52da27e..ae4ed4780 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -56,7 +56,6 @@ class StandardPdfPipeline(PaginatedPipeline):
 
     def __init__(self, pipeline_options: PdfPipelineOptions):
         super().__init__(pipeline_options)
-        print("------> Init Standard PDF Pipeline!")
 
         self.pipeline_options: PdfPipelineOptions
 
         artifacts_path: Optional[Path] = None
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index fd6e9dfce..31e40e7c8 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -117,6 +117,7 @@ def __init__(self, pipeline_options: VlmPipelineOptions):
 
         self.build_pipe = [
             SmolDoclingModel(
+                enabled=pipeline_options.do_vlm,
                 artifacts_path=artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
                 vlm_options=self.pipeline_options.vlm_options,
@@ -315,6 +316,7 @@ def otsl_extract_tokens_and_text(s: str):
                 token
                 for token in tokens
                 if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
+                # if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
             ]
             # Split the string by those tokens to get the in-between text
             text_parts = re.split(pattern, s)
@@ -322,6 +324,7 @@ def otsl_extract_tokens_and_text(s: str):
                 token
                 for token in text_parts
                 if not (token.startswith("<loc_") or token in ["<otsl>", "</otsl>"])
+                # if not (token.startswith(DocumentToken.BEG_LOC) or token in [DocumentToken.BEG_OTSL, DocumentToken.END_OTSL])
             ]
             # Remove any empty or purely whitespace strings from text_parts
             text_parts = [part for part in text_parts if part.strip()]
@@ -365,10 +368,15 @@ def parse_table_content(otsl_content: str) -> TableData:
 
         # Regex for all recognized tags
         tag_pattern = (
-            r"<(?P<tag>title|document_index|otsl|section_header_level_1|checkbox_selected|"
-            r"checkbox_unselected|text|page_header|page_footer|formula|caption|picture|"
-            r"list_item|footnote|code)>.*?</(?P=tag)>"
+            rf"<(?P<tag>{DocItemLabel.TITLE}|{DocItemLabel.DOCUMENT_INDEX}|"
+            rf"{DocItemLabel.CHECKBOX_UNSELECTED}|{DocItemLabel.CHECKBOX_SELECTED}|"
+            rf"{DocItemLabel.TEXT}|{DocItemLabel.PAGE_HEADER}|"
+            rf"{DocItemLabel.PAGE_FOOTER}|{DocItemLabel.FORMULA}|"
+            rf"{DocItemLabel.CAPTION}|{DocItemLabel.PICTURE}|"
+            rf"{DocItemLabel.LIST_ITEM}|{DocItemLabel.FOOTNOTE}|{DocItemLabel.CODE}|"
+            rf"{DocItemLabel.SECTION_HEADER}_level_1|otsl)>.*?</(?P=tag)>"
         )
+
         pattern = re.compile(tag_pattern, re.DOTALL)
 
         # Go through each match in order
@@ -456,8 +464,8 @@ def parse_table_content(otsl_content: str) -> TableData:
         return doc
 
     @classmethod
-    def get_default_options(cls) -> PdfPipelineOptions:
-        return PdfPipelineOptions()
+    def get_default_options(cls) -> VlmPipelineOptions:
+        return VlmPipelineOptions()
 
     @classmethod
     def is_backend_supported(cls, backend: AbstractDocumentBackend):
diff --git a/docs/examples/minimal_smol_docling.py b/docs/examples/minimal_smol_docling.py
index 5d64dee4e..66252f7b6 100644
--- a/docs/examples/minimal_smol_docling.py
+++ b/docs/examples/minimal_smol_docling.py
@@ -19,7 +19,7 @@
 pipeline_options.generate_page_images = True
 # If force_backend_text = True, text from backend will be used instead of generated text
 pipeline_options.force_backend_text = False
-
+# pipeline_options.do_vlm = True - set to False to disable the VLM model (i.e. SmolDocling); the extra Python imports will then be skipped
 vlm_options = SmolDoclingOptions(
     # question="Convert this page to docling.",
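
Usage note: with this change, SmolDoclingModel defers its torch/transformers imports until the model is actually enabled, so a pipeline built with do_vlm=False never pays for those imports. The sketch below shows how the new option would be exercised end to end. It is a minimal sketch, not part of the patch: the converter wiring follows the pattern of docs/examples/minimal_smol_docling.py, and "report.pdf" is a placeholder input path.

    # Minimal sketch of the new do_vlm flag (not part of the patch).
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import SmolDoclingOptions, VlmPipelineOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    pipeline_options = VlmPipelineOptions()
    # Setting do_vlm=False skips SmolDocling inference entirely and avoids
    # the torch/transformers imports inside SmolDoclingModel.__init__.
    pipeline_options.do_vlm = True
    pipeline_options.force_backend_text = False
    pipeline_options.vlm_options = SmolDoclingOptions(quantized=False)

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )

    result = converter.convert("report.pdf")  # placeholder input file
    print(result.document.export_to_markdown())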