diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 5bd28ed65..db8802b55 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -147,6 +147,10 @@ class LayoutPrediction(BaseModel): clusters: List[Cluster] = [] +class DocTagsPrediction(BaseModel): + tag_string: str = "" + + class ContainerElement( BasePageElement ): # Used for Form and Key-Value-Regions, only for typing. @@ -190,6 +194,7 @@ class PagePredictions(BaseModel): tablestructure: Optional[TableStructurePrediction] = None figures_classification: Optional[FigureClassificationPrediction] = None equations_prediction: Optional[EquationPrediction] = None + doctags: Optional[DocTagsPrediction] = None PageElement = Union[TextElement, Table, FigureElement, ContainerElement] diff --git a/docling/models/smol_docling_model.py b/docling/models/smol_docling_model.py new file mode 100644 index 000000000..056f59e8b --- /dev/null +++ b/docling/models/smol_docling_model.py @@ -0,0 +1,58 @@ +import copy +import logging +import random +import time +from pathlib import Path +from typing import Iterable, List + +from docling_core.types.doc import CoordOrigin, DocItemLabel +from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor +from PIL import Image, ImageDraw, ImageFont + +from docling.datamodel.base_models import ( + BoundingBox, + Cell, + Cluster, + DocTagsPrediction, + LayoutPrediction, + Page, +) +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions +from docling.datamodel.settings import settings +from docling.models.base_model import BasePageModel +from docling.utils.accelerator_utils import decide_device +from docling.utils.layout_postprocessor import LayoutPostprocessor +from docling.utils.profiling import TimeRecorder + +_log = logging.getLogger(__name__) + + +class SmolDoclingModel(BasePageModel): + + def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): + device = decide_device(accelerator_options.device) + + # self.your_vlm_predictor(..., device) = None # TODO + + def __call__( + self, conv_res: ConversionResult, page_batch: Iterable[Page] + ) -> Iterable[Page]: + + for page in page_batch: + assert page._backend is not None + if not page._backend.is_valid(): + yield page + else: + with TimeRecorder(conv_res, "smolvlm"): + assert page.size is not None + + hi_res_image = page.get_image(scale=2.0) # 144dpi + + # Call your self.your_vlm_predictor with the page image as input (hi_res_image) + # populate page_tags + page_tags = "" + + page.predictions.doctags = DocTagsPrediction(tag_string=page_tags) + + yield page diff --git a/docling/pipeline/smol_docling_pipeline.py b/docling/pipeline/smol_docling_pipeline.py new file mode 100644 index 000000000..f4373beb1 --- /dev/null +++ b/docling/pipeline/smol_docling_pipeline.py @@ -0,0 +1,162 @@ +import logging +import sys +from pathlib import Path +from typing import Optional + +from docling_core.types import DoclingDocument +from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem + +from docling.backend.abstract_backend import AbstractDocumentBackend +from docling.backend.pdf_backend import PdfDocumentBackend +from docling.datamodel.base_models import AssembledUnit, Page +from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import ( + EasyOcrOptions, + OcrMacOptions, + PdfPipelineOptions, + RapidOcrOptions, + TesseractCliOcrOptions, + TesseractOcrOptions, +) +from docling.models.base_ocr_model import BaseOcrModel +from docling.models.ds_glm_model import GlmModel, GlmOptions +from docling.models.easyocr_model import EasyOcrModel +from docling.models.layout_model import LayoutModel +from docling.models.ocr_mac_model import OcrMacModel +from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions +from docling.models.page_preprocessing_model import ( + PagePreprocessingModel, + PagePreprocessingOptions, +) +from docling.models.rapid_ocr_model import RapidOcrModel +from docling.models.smol_docling_model import SmolDoclingModel +from docling.models.table_structure_model import TableStructureModel +from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel +from docling.models.tesseract_ocr_model import TesseractOcrModel +from docling.pipeline.base_pipeline import PaginatedPipeline +from docling.utils.profiling import ProfilingScope, TimeRecorder + +_log = logging.getLogger(__name__) + + +class SmolDoclingPipeline(PaginatedPipeline): + _smol_vlm_path = "model_artifacts/smol_vlm" # TODO or whatever is needed. + + def __init__(self, pipeline_options: PdfPipelineOptions): + super().__init__(pipeline_options) + self.pipeline_options: PdfPipelineOptions + + if pipeline_options.artifacts_path is None: + self.artifacts_path = self.download_models_hf() + else: + self.artifacts_path = Path(pipeline_options.artifacts_path) + + keep_images = ( + self.pipeline_options.generate_page_images + or self.pipeline_options.generate_picture_images + or self.pipeline_options.generate_table_images + ) + + self.build_pipe = [ + SmolDoclingModel( + artifacts_path=self.artifacts_path / SmolDoclingPipeline._smol_vlm_path, + accelerator_options=pipeline_options.accelerator_options, + ), + ] + + self.enrichment_pipe = [ + # Other models working on `NodeItem` elements in the DoclingDocument + ] + + @staticmethod + def download_models_hf( + local_dir: Optional[Path] = None, force: bool = False + ) -> Path: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import disable_progress_bars + + disable_progress_bars() + + # TODO download the correct model (private repo) + download_path = snapshot_download( + repo_id="ds4sd/xxx", + force_download=force, + local_dir=local_dir, + ) + + return Path(download_path) + + def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: + with TimeRecorder(conv_res, "page_init"): + page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore + if page._backend is not None and page._backend.is_valid(): + page.size = page._backend.get_size() + + return page + + def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult: + + with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT): + + # Read and concatenate the page doctags: + document_tags = "" + for page in conv_res.pages: + if page.predictions.doctags is not None: + document_tags += page.predictions.doctags.tag_string + + # TODO implement this function + conv_res.document = self._turn_tags_into_doc(document_tags) + + # Generate page images in the output + if self.pipeline_options.generate_page_images: + for page in conv_res.pages: + assert page.image is not None + page_no = page.page_no + 1 + conv_res.document.pages[page_no].image = ImageRef.from_pil( + page.image, dpi=int(72 * self.pipeline_options.images_scale) + ) + + # Generate images of the requested element types + if ( + self.pipeline_options.generate_picture_images + or self.pipeline_options.generate_table_images + ): + scale = self.pipeline_options.images_scale + for element, _level in conv_res.document.iterate_items(): + if not isinstance(element, DocItem) or len(element.prov) == 0: + continue + if ( + isinstance(element, PictureItem) + and self.pipeline_options.generate_picture_images + ) or ( + isinstance(element, TableItem) + and self.pipeline_options.generate_table_images + ): + page_ix = element.prov[0].page_no - 1 + page = conv_res.pages[page_ix] + assert page.size is not None + assert page.image is not None + + crop_bbox = ( + element.prov[0] + .bbox.scaled(scale=scale) + .to_top_left_origin(page_height=page.size.height * scale) + ) + + cropped_im = page.image.crop(crop_bbox.as_tuple()) + element.image = ImageRef.from_pil( + cropped_im, dpi=int(72 * scale) + ) + + return conv_res + + @classmethod + def get_default_options(cls) -> PdfPipelineOptions: + return PdfPipelineOptions() + + @classmethod + def is_backend_supported(cls, backend: AbstractDocumentBackend): + return isinstance(backend, PdfDocumentBackend) + + def _turn_tags_into_doc(self, document_tags): + return DoclingDocument() diff --git a/docs/examples/minimal_smol_docling.py b/docs/examples/minimal_smol_docling.py new file mode 100644 index 000000000..7fcc19504 --- /dev/null +++ b/docs/examples/minimal_smol_docling.py @@ -0,0 +1,15 @@ +from docling.datamodel.base_models import InputFormat +from docling.document_converter import DocumentConverter, PdfFormatOption +from docling.pipeline.smol_docling_pipeline import SmolDoclingPipeline + +source = "https://arxiv.org/pdf/2408.09869" # document per local path or URL +converter = DocumentConverter( + doc_converter=DocumentConverter( + format_options={ + InputFormat.PDF: PdfFormatOption(pipeline_cls=SmolDoclingPipeline) + } + ) +) +result = converter.convert(source) +print(result.document.export_to_markdown()) +# output: ## Docling Technical Report [...]"