From 6895ddbd214c8895ca866a691816bb823bca715b Mon Sep 17 00:00:00 2001 From: Yao You Date: Wed, 15 Jan 2025 08:20:43 -0600 Subject: [PATCH] fix: fix bugs in data structure (#402) - fix bug when an empty list is passed into `TextRegions.from_list` - fix bug when concatenating a list of `LayoutElements` the class id maps is not updated correctly --- CHANGELOG.md | 7 +++++ test_unstructured_inference/test_elements.py | 28 +++++++++++++++++++ unstructured_inference/__version__.py | 2 +- unstructured_inference/inference/elements.py | 3 +- .../inference/layoutelement.py | 17 ++++++++--- 5 files changed, 51 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 75673742..19a3787e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,11 @@ +## 0.8.2 + +* fix: fix bug when an empty list is passed into `TextRegions.from_list` triggers `IndexError` +* fix: fix bug when concatenate a list of `LayoutElements` the class id mapping is no properly + updated + ## 0.8.1 + * fix: fix list index out of range error caused by calling LayoutElements.from_list() with empty list ## 0.8.0 diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index cb730335..071c0840 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -441,3 +441,31 @@ def test_layoutelements_to_list_and_back(test_layoutelements): def test_layoutelements_from_list_no_elements(): back = LayoutElements.from_list(elements=[]) assert back.source is None + assert back.element_coords.size == 0 + + +def test_textregions_from_list_no_elements(): + back = TextRegions.from_list(regions=[]) + assert back.source is None + assert back.element_coords.size == 0 + + +def test_layoutelements_concatenate(): + layout1 = LayoutElements( + element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]), + texts=np.array(["a", "two"]), + source=None, + element_class_ids=np.array([0, 1]), + element_class_id_map={0: "type0", 1: "type1"}, + ) + layout2 = LayoutElements( + element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]), + texts=np.array(["three", "4"]), + source=None, + element_class_ids=np.array([0, 1]), + element_class_id_map={0: "type1", 1: "type2"}, + ) + joint = LayoutElements.concatenate([layout1, layout2]) + assert joint.texts.tolist() == ["a", "two", "three", "4"] + assert joint.element_class_ids.tolist() == [0, 1, 1, 2] + assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"} diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 6719f022..c290d75e 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.8.1" # pragma: no cover +__version__ = "0.8.2" # pragma: no cover diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 939ea0cc..19fa02ff 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -244,7 +244,8 @@ def from_list(cls, regions: list): for region in regions: coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2)) texts.append(region.text) - return cls(element_coords=np.array(coords), texts=np.array(texts), source=regions[0].source) + source = regions[0].source if regions else None + return cls(element_coords=np.array(coords), texts=np.array(texts), source=source) def __len__(self): return self.element_coords.shape[0] diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index f1b40ab4..8ccaeb9c 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -74,22 +74,31 @@ def slice(self, indices) -> LayoutElements: def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: """concatenate a sequence of LayoutElements in order as one LayoutElements""" coords, texts, probs, class_ids, sources = [], [], [], [], [] - class_id_map = {} + class_id_reverse_map: dict[str, int] = {} for group in groups: coords.append(group.element_coords) texts.append(group.texts) probs.append(group.element_probs) - class_ids.append(group.element_class_ids) if group.source: sources.append(group.source) + + idx = group.element_class_ids.copy() if group.element_class_id_map: - class_id_map.update(group.element_class_id_map) + for class_id, class_name in group.element_class_id_map.items(): + if class_name in class_id_reverse_map: + idx[group.element_class_ids == class_id] = class_id_reverse_map[class_name] + continue + new_id = len(class_id_reverse_map) + class_id_reverse_map[class_name] = new_id + idx[group.element_class_ids == class_id] = new_id + class_ids.append(idx) + return cls( element_coords=np.concatenate(coords), texts=np.concatenate(texts), element_probs=np.concatenate(probs), element_class_ids=np.concatenate(class_ids), - element_class_id_map=class_id_map, + element_class_id_map={v: k for k, v in class_id_reverse_map.items()}, source=sources[0] if sources else None, )