Skip to content

Commit

Permalink
Skeleton for SmolDocling pipeline
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Auer <[email protected]>
  • Loading branch information
cau-git committed Jan 8, 2025
1 parent ead396a commit e4a60ae
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docling/datamodel/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []


class DocTagsPrediction(BaseModel):
    """Per-page prediction holding the DocTags markup as a single string."""

    # DocTags markup for the page; defaults to the empty string.
    tag_string: str = ""


class ContainerElement(
BasePageElement
): # Used for Form and Key-Value-Regions, only for typing.
Expand Down Expand Up @@ -190,6 +194,7 @@ class PagePredictions(BaseModel):
tablestructure: Optional[TableStructurePrediction] = None
figures_classification: Optional[FigureClassificationPrediction] = None
equations_prediction: Optional[EquationPrediction] = None
doctags: Optional[DocTagsPrediction] = None


PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
Expand Down
58 changes: 58 additions & 0 deletions docling/models/smol_docling_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import copy
import logging
import random
import time
from pathlib import Path
from typing import Iterable, List

from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image, ImageDraw, ImageFont

from docling.datamodel.base_models import (
BoundingBox,
Cell,
Cluster,
DocTagsPrediction,
LayoutPrediction,
Page,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder

# Module-level logger, named after this module's import path.
_log = logging.getLogger(__name__)


class SmolDoclingModel(BasePageModel):
    """Skeleton page model that attaches a DocTags prediction to each page.

    The actual vision-language predictor is not wired in yet (see the TODO
    markers); for now every valid page receives an empty tag string.
    """

    def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
        """Remember where the model artifacts live and which device to use.

        Args:
            artifacts_path: Location of the model artifacts.
            accelerator_options: Accelerator settings; its ``device`` field is
                resolved through ``decide_device``.
        """
        # Keep both values on the instance so the predictor, once added, can
        # be built from them instead of them being discarded.
        self.artifacts_path = artifacts_path
        self.device = decide_device(accelerator_options.device)

        # self.your_vlm_predictor(..., device) = None # TODO

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Yield every page of the batch, adding ``predictions.doctags``.

        Pages whose backend reports itself as invalid are passed through
        untouched.

        Args:
            conv_res: Conversion result used as the target for timing records.
            page_batch: Pages to process.

        Yields:
            Each input page, in order.
        """
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "smolvlm"):
                assert page.size is not None

                hi_res_image = page.get_image(scale=2.0)  # 144dpi

                # Call your self.your_vlm_predictor with the page image as
                # input (hi_res_image) and populate page_tags.
                page_tags = ""

                page.predictions.doctags = DocTagsPrediction(tag_string=page_tags)

            yield page
162 changes: 162 additions & 0 deletions docling/pipeline/smol_docling_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import logging
import sys
from pathlib import Path
from typing import Optional

from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem

from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
OcrMacOptions,
PdfPipelineOptions,
RapidOcrOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.smol_docling_model import SmolDoclingModel
from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder

# Module-level logger, named after this module's import path.
_log = logging.getLogger(__name__)


class SmolDoclingPipeline(PaginatedPipeline):
    """Paginated pipeline that assembles a document from per-page DocTags.

    Every page is passed through ``SmolDoclingModel``, which stores a DocTags
    string in ``page.predictions.doctags``. During assembly the page strings
    are concatenated and handed to ``_turn_tags_into_doc``.
    """

    _smol_vlm_path = "model_artifacts/smol_vlm"  # TODO or whatever is needed.

    def __init__(self, pipeline_options: PdfPipelineOptions):
        """Resolve the artifacts directory and build the model pipe.

        Args:
            pipeline_options: Pipeline configuration. When its
                ``artifacts_path`` is ``None`` the artifacts are downloaded
                via ``download_models_hf``.
        """
        super().__init__(pipeline_options)
        self.pipeline_options: PdfPipelineOptions

        if pipeline_options.artifacts_path is None:
            self.artifacts_path = self.download_models_hf()
        else:
            self.artifacts_path = Path(pipeline_options.artifacts_path)

        self.build_pipe = [
            SmolDoclingModel(
                artifacts_path=self.artifacts_path / SmolDoclingPipeline._smol_vlm_path,
                accelerator_options=pipeline_options.accelerator_options,
            ),
        ]

        self.enrichment_pipe = [
            # Other models working on `NodeItem` elements in the DoclingDocument
        ]

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        """Download the model snapshot from the Hugging Face hub.

        Args:
            local_dir: Optional target directory, forwarded to
                ``snapshot_download``.
            force: Forwarded as ``force_download``.

        Returns:
            Path of the downloaded snapshot.
        """
        # Imported lazily so the hub client is only needed when downloading.
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()

        # TODO download the correct model (private repo)
        download_path = snapshot_download(
            repo_id="ds4sd/xxx",
            force_download=force,
            local_dir=local_dir,
        )

        return Path(download_path)

    def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
        """Load the page backend and, if it is valid, record the page size."""
        with TimeRecorder(conv_res, "page_init"):
            page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
            if page._backend is not None and page._backend.is_valid():
                page.size = page._backend.get_size()

        return page

    def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
        """Build ``conv_res.document`` from the page DocTags and attach images.

        Args:
            conv_res: Conversion result whose pages carry the predictions.

        Returns:
            The same ``conv_res`` with ``document`` populated.
        """
        with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):

            # Read and concatenate the page doctags; pages without a
            # prediction contribute nothing.
            document_tags = "".join(
                page.predictions.doctags.tag_string
                for page in conv_res.pages
                if page.predictions.doctags is not None
            )

            # TODO implement this function
            conv_res.document = self._turn_tags_into_doc(document_tags)

            # Generate page images in the output
            if self.pipeline_options.generate_page_images:
                for page in conv_res.pages:
                    # Document pages are keyed 1-based, pipeline pages 0-based.
                    page_no = page.page_no + 1
                    page_item = conv_res.document.pages.get(page_no)
                    if page_item is None:
                        # The tag parser did not create this page; skip it
                        # rather than fail the whole conversion on a KeyError.
                        _log.warning(
                            "Page %d is missing from the assembled document; "
                            "skipping its page image.",
                            page_no,
                        )
                        continue
                    assert page.image is not None
                    page_item.image = ImageRef.from_pil(
                        page.image, dpi=int(72 * self.pipeline_options.images_scale)
                    )

            # Generate images of the requested element types
            if (
                self.pipeline_options.generate_picture_images
                or self.pipeline_options.generate_table_images
            ):
                scale = self.pipeline_options.images_scale
                for element, _level in conv_res.document.iterate_items():
                    if not isinstance(element, DocItem) or len(element.prov) == 0:
                        continue
                    if (
                        isinstance(element, PictureItem)
                        and self.pipeline_options.generate_picture_images
                    ) or (
                        isinstance(element, TableItem)
                        and self.pipeline_options.generate_table_images
                    ):
                        page_ix = element.prov[0].page_no - 1
                        page = conv_res.pages[page_ix]
                        assert page.size is not None
                        assert page.image is not None

                        # Scale the bbox to image resolution and flip it to a
                        # top-left origin before cropping the page image.
                        crop_bbox = (
                            element.prov[0]
                            .bbox.scaled(scale=scale)
                            .to_top_left_origin(page_height=page.size.height * scale)
                        )

                        cropped_im = page.image.crop(crop_bbox.as_tuple())
                        element.image = ImageRef.from_pil(
                            cropped_im, dpi=int(72 * scale)
                        )

        return conv_res

    @classmethod
    def get_default_options(cls) -> PdfPipelineOptions:
        """Return a default-constructed ``PdfPipelineOptions``."""
        return PdfPipelineOptions()

    @classmethod
    def is_backend_supported(cls, backend: AbstractDocumentBackend) -> bool:
        """Return whether ``backend`` is a ``PdfDocumentBackend``."""
        return isinstance(backend, PdfDocumentBackend)

    def _turn_tags_into_doc(
        self, document_tags: str, name: str = "Document"
    ) -> DoclingDocument:
        """Convert the concatenated DocTags string into a ``DoclingDocument``.

        Placeholder: the tags are not parsed yet and an empty document is
        returned.

        Args:
            document_tags: DocTags of all pages, concatenated in page order.
            name: Working name given to the created document.
        """
        # DoclingDocument needs a name to be constructed; without one the
        # model validation fails.
        return DoclingDocument(name=name)
15 changes: 15 additions & 0 deletions docs/examples/minimal_smol_docling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.smol_docling_pipeline import SmolDoclingPipeline

source = "https://arxiv.org/pdf/2408.09869"  # document per local path or URL

# Route PDF inputs through the SmolDocling pipeline. The format options are
# passed directly to the converter; it has no `doc_converter` argument.
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_cls=SmolDoclingPipeline)
    }
)
result = converter.convert(source)
print(result.document.export_to_markdown())
# output: ## Docling Technical Report [...]

0 comments on commit e4a60ae

Please sign in to comment.