Skip to content

Commit

Permalink
feat: add back source attribute for backward compatibility (#407)
Browse files Browse the repository at this point in the history
This PR adds `source` back to `TextRegions` and `LayoutElements` for
backward compatibility.
  • Loading branch information
badGarnet authored Jan 22, 2025
1 parent 655ea34 commit 85bcdc1
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 5 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
## 0.8.6

* feat: add back `source` to `TextRegions` and `LayoutElements` for backward compatibility

## 0.8.5

* fix: remove `pdfplumber` but include `pdfminer-six==20240706` to update `pdfminer`
* fix: remove `pdfplumber` but include `pdfminer-six==20240706` to update `pdfminer`

## 0.8.4

Expand Down
6 changes: 4 additions & 2 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_layoutelements():
element_coords=coords,
element_class_ids=element_class_ids,
element_class_id_map=class_map,
sources=np.array(["yolox"] * len(element_class_ids)),
source="yolox",
)


Expand Down Expand Up @@ -441,20 +441,22 @@ def test_layoutelements_to_list_and_back(test_layoutelements):
def test_layoutelements_from_list_no_elements():
    """Building LayoutElements from an empty list gives empty arrays and an unset source."""
    empty_layout = LayoutElements.from_list(elements=[])
    # per-region sources are empty, and the single-value `source` stays unset
    assert not empty_layout.sources.size
    assert empty_layout.source is None
    assert not empty_layout.element_coords.size


def test_textregions_from_list_no_elements():
    """Building TextRegions from an empty list gives empty arrays and an unset source."""
    empty_regions = TextRegions.from_list(regions=[])
    # per-region sources are empty, and the single-value `source` stays unset
    assert not empty_regions.sources.size
    assert empty_regions.source is None
    assert not empty_regions.element_coords.size


def test_layoutelements_concatenate():
layout1 = LayoutElements(
element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]),
texts=np.array(["a", "two"]),
sources=np.array(["yolox", "yolox"]),
source="yolox",
element_class_ids=np.array([0, 1]),
element_class_id_map={0: "type0", 1: "type1"},
)
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.5" # pragma: no cover
__version__ = "0.8.6" # pragma: no cover
10 changes: 10 additions & 0 deletions unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,21 @@ class TextRegions:
element_coords: np.ndarray
texts: np.ndarray = field(default_factory=lambda: np.array([]))
sources: np.ndarray = field(default_factory=lambda: np.array([]))
source: Source | None = None

def __post_init__(self):
if self.texts.size == 0 and self.element_coords.size > 0:
self.texts = np.array([None] * self.element_coords.shape[0])

# for backward compatibility; also allow to use one value to set sources for all regions
if self.sources.size == 0 and self.element_coords.size > 0:
self.sources = np.array([self.source] * self.element_coords.shape[0])
elif self.source is None and self.sources.size:
self.source = self.sources[0]

# we convert to float so data type is more consistent (e.g., None will be np.nan)
self.element_coords = self.element_coords.astype(float)

def slice(self, indices) -> TextRegions:
"""slice text regions based on indices"""
return TextRegions(
Expand Down
7 changes: 6 additions & 1 deletion unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,18 @@ def __post_init__(self):
"element_probs",
"element_class_ids",
"texts",
"sources",
"text_as_html",
"table_as_cells",
):
if getattr(self, attr).size == 0 and element_size:
setattr(self, attr, np.array([None] * element_size))

# for backward compatibility; also allows a single `source` value to set the sources of all regions
if self.sources.size == 0 and self.element_coords.size > 0:
self.sources = np.array([self.source] * self.element_coords.shape[0])
elif self.source is None and self.sources.size:
self.source = self.sources[0]

self.element_probs = self.element_probs.astype(float)

def __eq__(self, other: object) -> bool:
Expand Down

0 comments on commit 85bcdc1

Please sign in to comment.