diff --git a/CHANGELOG.md b/CHANGELOG.md index cc849297..0e4e0183 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 1.1.0 + +* Enhancement: Add `TextSource` to track where the text of an element came from +* Enhancement: Refactor `__post_init__` of `TextRegions` and `LayoutElement` slightly to automate initialization + ## 1.0.10 * Remove merging logic that's no longer used diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index 192bf8ed..b7756de2 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -13,6 +13,7 @@ EmbeddedTextRegion, ImageTextRegion, ) +from unstructured_inference.constants import IsExtracted from unstructured_inference.models.unstructuredmodel import ( UnstructuredElementExtractionModel, UnstructuredObjectDetectionModel, @@ -34,7 +35,7 @@ def mock_initial_layout(): 6, 8, text="A very repetitive narrative. " * 10, - source="Mock", + is_extracted=IsExtracted.TRUE, ) title_block = EmbeddedTextRegion.from_coords( @@ -43,7 +44,7 @@ def mock_initial_layout(): 3, 4, text="A Catchy Title", - source="Mock", + is_extracted=IsExtracted.TRUE, ) return [text_block, title_block] diff --git a/test_unstructured_inference/inference/test_layout_element.py b/test_unstructured_inference/inference/test_layout_element.py index b0cd2966..49b71100 100644 --- a/test_unstructured_inference/inference/test_layout_element.py +++ b/test_unstructured_inference/inference/test_layout_element.py @@ -1,10 +1,12 @@ from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion +from unstructured_inference.constants import IsExtracted, Source -def test_layout_element_do_dict(mock_layout_element): +def test_layout_element_to_dict(mock_layout_element): expected = { "coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)), "text": "Sample text", + "is_extracted": None, "type": "Text", "prob": None, "source": None, @@ -18,3 +20,31 @@ def test_layout_element_from_region(mock_rectangle): region = TextRegion(bbox=mock_rectangle) assert LayoutElement.from_region(region) == expected + + +def test_layoutelement_inheritance_works_correctly(): + """Test that LayoutElement properly inherits from TextRegion without conflicts""" + from unstructured_inference.inference.elements import TextRegion + + # Create a TextRegion with both source and text_source + region = TextRegion.from_coords( + 0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE + ) + + # Convert to LayoutElement + element = LayoutElement.from_region(region) + + # Check that both properties are preserved + assert element.source == Source.YOLOX, "LayoutElement should inherit source from TextRegion" + assert ( + element.is_extracted == IsExtracted.TRUE + ), "LayoutElement should inherit is_extracted from TextRegion" + + # Check that to_dict() works correctly + d = element.to_dict() + assert d["source"] == Source.YOLOX + assert d["is_extracted"] == IsExtracted.TRUE + + # Check that we can set source directly on LayoutElement + element.source = Source.DETECTRON2_ONNX + assert element.source == Source.DETECTRON2_ONNX diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index 1dc23595..7a8b937f 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -5,9 +5,11 @@ import numpy as np import pytest +from unstructured_inference.constants import IsExtracted, Source from unstructured_inference.inference import elements from unstructured_inference.inference.elements import ( Rectangle, + TextRegion, TextRegions, ) from unstructured_inference.inference.layoutelement import ( @@ -56,7 +58,7 @@ def test_layoutelements(): element_coords=coords, element_class_ids=element_class_ids, element_class_id_map=class_map, - source="yolox", + source=Source.YOLOX, ) @@ -307,7 +309,7 @@ def test_clean_layoutelements(test_layoutelements): elements[1].bbox.x2, elements[1].bbox.x2, ) == (2, 2, 3, 3) - assert elements[0].source == elements[1].source == "yolox" + assert elements[0].source == elements[1].source == Source.YOLOX @pytest.mark.parametrize( @@ -408,8 +410,8 @@ def test_layoutelements_from_list_no_elements(): def test_textregions_from_list_no_elements(): back = TextRegions.from_list(regions=[]) - assert back.sources.size == 0 - assert back.source is None + assert back.is_extracted_array.size == 0 + assert back.is_extracted is None assert back.element_coords.size == 0 @@ -417,20 +419,25 @@ def test_layoutelements_concatenate(): layout1 = LayoutElements( element_coords=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]), texts=np.array(["a", "two"]), - source="yolox", + source=Source.YOLOX, element_class_ids=np.array([0, 1]), element_class_id_map={0: "type0", 1: "type1"}, ) layout2 = LayoutElements( element_coords=np.array([[10, 10, 2, 2], [20, 20, 1, 1]]), texts=np.array(["three", "4"]), - sources=np.array(["ocr", "ocr"]), + sources=np.array([Source.DETECTRON2_ONNX, Source.DETECTRON2_ONNX]), element_class_ids=np.array([0, 1]), element_class_id_map={0: "type1", 1: "type2"}, ) joint = LayoutElements.concatenate([layout1, layout2]) assert joint.texts.tolist() == ["a", "two", "three", "4"] - assert joint.sources.tolist() == ["yolox", "yolox", "ocr", "ocr"] + assert [s.value for s in joint.sources.tolist()] == [ + "yolox", + "yolox", + "detectron2_onnx", + "detectron2_onnx", + ] assert joint.element_class_ids.tolist() == [0, 1, 1, 2] assert joint.element_class_id_map == {0: "type0", 1: "type1", 2: "type2"} @@ -449,8 +456,8 @@ def test_layoutelements_concatenate(): ] ), texts=np.array(["0", "1", "2", "3", "4"]), - sources=np.array(["foo", "foo", "foo", "foo", "foo"], dtype=" 0, "sources array should not be empty" + assert text_regions.sources[0] == Source.YOLOX + assert text_regions.sources[1] == Source.DETECTRON2_ONNX + + +def test_textregions_has_sources_field(): + """Test that TextRegions has a sources field""" + text_regions = TextRegions(element_coords=np.array([[0, 0, 10, 10]])) + + # This should fail because TextRegions doesn't have a sources field + assert hasattr(text_regions, "sources"), "TextRegions should have a sources field" + assert hasattr(text_regions, "source"), "TextRegions should have a source field" + + +def test_textregions_iter_elements_preserves_source(): + """Test that TextRegions.iter_elements() preserves source property""" + from unstructured_inference.inference.elements import TextRegion + + regions = [ + TextRegion.from_coords( + 0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE + ), + ] + text_regions = TextRegions.from_list(regions) + + elements = list(text_regions.iter_elements()) + + # This should fail because iter_elements() doesn't pass source to TextRegion.from_coords() + assert elements[0].source == Source.YOLOX, "iter_elements() should preserve source" + + +def test_textregions_slice_preserves_sources(): + """Test that TextRegions slicing preserves sources array""" + from unstructured_inference.inference.elements import TextRegion + + regions = [ + TextRegion.from_coords( + 0, 0, 10, 10, text="first", source=Source.YOLOX, is_extracted=IsExtracted.TRUE + ), + TextRegion.from_coords( + 10, + 10, + 20, + 20, + text="second", + source=Source.DETECTRON2_ONNX, + is_extracted=IsExtracted.TRUE, + ), + ] + text_regions = TextRegions.from_list(regions) + + sliced = text_regions[0:1] + + # This should fail because slice() doesn't handle sources + assert sliced.sources.size > 0, "Sliced TextRegions should have sources" + assert sliced.sources[0] == Source.YOLOX + assert sliced.is_extracted_array[0] is IsExtracted.TRUE + + +def test_textregions_post_init_handles_sources(): + """Test that TextRegions.__post_init__() handles sources array initialization""" + # Create with source but no sources array + text_regions = TextRegions( + element_coords=np.array([[0, 0, 10, 10], [10, 10, 20, 20]]), source=Source.YOLOX + ) + + # This should fail because __post_init__() doesn't handle sources + assert text_regions.sources.size > 0, "sources should be initialized from source" + assert text_regions.sources[0] == Source.YOLOX + assert text_regions.sources[1] == Source.YOLOX + + +def test_textregions_from_coords_accepts_source(): + """Test that TextRegion.from_coords() accepts source parameter""" + # This should fail because from_coords() doesn't accept source parameter + region = TextRegion.from_coords( + 0, 0, 10, 10, text="test", source=Source.YOLOX, is_extracted=IsExtracted.TRUE + ) + + assert region.source == Source.YOLOX + assert region.is_extracted diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index a2ba73b7..1bccd442 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "1.0.10" # pragma: no cover +__version__ = "1.1.0" # pragma: no cover diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py index c68e02d3..a3603c9d 100644 --- a/unstructured_inference/constants.py +++ b/unstructured_inference/constants.py @@ -7,6 +7,12 @@ class Source(Enum): DETECTRON2_LP = "detectron2_lp" +class IsExtracted(Enum): + TRUE = "true" + FALSE = "false" + PARTIAL = "partial" + + class ElementType: PARAGRAPH = "Paragraph" IMAGE = "Image" diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 81647ced..4e1791b5 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -7,7 +7,7 @@ import numpy as np -from unstructured_inference.constants import Source +from unstructured_inference.constants import IsExtracted, Source from unstructured_inference.math import safe_division @@ -185,6 +185,7 @@ class TextRegion: bbox: Rectangle text: Optional[str] = None source: Optional[Source] = None + is_extracted: Optional[IsExtracted] = None def __str__(self) -> str: return str(self.text) @@ -198,12 +199,13 @@ def from_coords( y2: Union[int, float], text: Optional[str] = None, source: Optional[Source] = None, + is_extracted: Optional[IsExtracted] = None, **kwargs, ) -> TextRegion: """Constructs a region from coordinates.""" bbox = Rectangle(x1, y1, x2, y2) - return cls(text=text, source=source, bbox=bbox, **kwargs) + return cls(text=text, source=source, is_extracted=is_extracted, bbox=bbox, **kwargs) @dataclass @@ -212,16 +214,33 @@ class TextRegions: texts: np.ndarray = field(default_factory=lambda: np.array([])) sources: np.ndarray = field(default_factory=lambda: np.array([])) source: Source | None = None + is_extracted_array: np.ndarray = field(default_factory=lambda: np.array([])) + is_extracted: IsExtracted | None = None + _optional_array_attributes: list[str] = field( + init=False, default_factory=lambda: ["texts", "sources", "is_extracted_array"] + ) + _scalar_to_array_mappings: dict[str, str] = field( + init=False, + default_factory=lambda: { + "source": "sources", + "is_extracted": "is_extracted_array", + }, + ) def __post_init__(self): - if self.texts.size == 0 and self.element_coords.size > 0: - self.texts = np.array([None] * self.element_coords.shape[0]) - - # for backward compatibility; also allow to use one value to set sources for all regions - if self.sources.size == 0 and self.element_coords.size > 0: - self.sources = np.array([self.source] * self.element_coords.shape[0]) - elif self.source is None and self.sources.size: - self.source = self.sources[0] + element_size = self.element_coords.shape[0] + for scalar, array in self._scalar_to_array_mappings.items(): + if ( + getattr(self, scalar) is not None + and getattr(self, array).size == 0 + and element_size + ): + setattr(self, array, np.array([getattr(self, scalar)] * element_size)) + elif getattr(self, scalar) is None and getattr(self, array).size > 0: + setattr(self, scalar, getattr(self, array)[0]) + for attr in self._optional_array_attributes: + if getattr(self, attr).size == 0 and element_size: + setattr(self, attr, np.array([None] * element_size)) # we convert to float so data type is more consistent (e.g., None will be np.nan) self.element_coords = self.element_coords.astype(float) @@ -231,21 +250,26 @@ def __getitem__(self, indices) -> TextRegions: def slice(self, indices) -> TextRegions: """slice text regions based on indices""" + # NOTE(alan): I would expect if I try to access a single element, it should return a + # TextRegion, not a TextRegions. Currently, you get an error when trying to access a single + # element. return TextRegions( element_coords=self.element_coords[indices], texts=self.texts[indices], sources=self.sources[indices], + is_extracted_array=self.is_extracted_array[indices], ) def iter_elements(self): """iter text regions as one TextRegion per iteration; this returns a generator and has less memory impact than the as_list method""" - for (x1, y1, x2, y2), text, source in zip( + for (x1, y1, x2, y2), text, source, is_extracted in zip( self.element_coords, self.texts, self.sources, + self.is_extracted_array, ): - yield TextRegion.from_coords(x1, y1, x2, y2, text, source) + yield TextRegion.from_coords(x1, y1, x2, y2, text, source, is_extracted) def as_list(self): """return a list of LayoutElement for backward compatibility""" @@ -254,16 +278,18 @@ def as_list(self): @classmethod def from_list(cls, regions: list): """create TextRegions from a list of TextRegion objects; the objects must have the same - source""" - coords, texts, sources = [], [], [] + is_extracted""" + coords, texts, sources, is_extracted_array = [], [], [], [] for region in regions: coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2)) texts.append(region.text) sources.append(region.source) + is_extracted_array.append(region.is_extracted) return cls( element_coords=np.array(coords), texts=np.array(texts), sources=np.array(sources), + is_extracted_array=np.array(is_extracted_array), ) def __len__(self): diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index 39ce1f51..fb3e073e 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -1,13 +1,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List, Optional, Union import numpy as np from pandas import DataFrame from scipy.sparse.csgraph import connected_components from unstructured_inference.config import inference_config +from unstructured_inference.constants import IsExtracted, Source from unstructured_inference.inference.elements import ( Rectangle, TextRegion, @@ -25,27 +26,28 @@ class LayoutElements(TextRegions): element_class_id_map: dict[int, str] = field(default_factory=dict) text_as_html: np.ndarray = field(default_factory=lambda: np.array([])) table_as_cells: np.ndarray = field(default_factory=lambda: np.array([])) - - def __post_init__(self): - element_size = self.element_coords.shape[0] - # NOTE: maybe we should create an attribute _optional_attributes: list[str] to store this - # list - for attr in ( + _optional_array_attributes: list[str] = field( + init=False, + default_factory=lambda: [ + "texts", + "sources", + "is_extracted_array", "element_probs", "element_class_ids", - "texts", "text_as_html", "table_as_cells", - ): - if getattr(self, attr).size == 0 and element_size: - setattr(self, attr, np.array([None] * element_size)) - - # for backward compatibility; also allow to use one value to set sources for all regions - if self.sources.size == 0 and self.element_coords.size > 0: - self.sources = np.array([self.source] * self.element_coords.shape[0]) - elif self.source is None and self.sources.size: - self.source = self.sources[0] + ], + ) + _scalar_to_array_mappings: dict[str, str] = field( + init=False, + default_factory=lambda: { + "source": "sources", + "is_extracted": "is_extracted_array", + }, + ) + def __post_init__(self): + super().__post_init__() self.element_probs = self.element_probs.astype(float) def __eq__(self, other: object) -> bool: @@ -64,6 +66,7 @@ def __eq__(self, other: object) -> bool: == [other.element_class_id_map[idx] for idx in other.element_class_ids] ) and np.array_equal(self.sources[mask], other.sources[mask]) + and np.array_equal(self.is_extracted_array[mask], other.is_extracted_array[mask]) and np.array_equal(self.text_as_html[mask], other.text_as_html[mask]) and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask]) ) @@ -76,6 +79,7 @@ def slice(self, indices) -> LayoutElements: return LayoutElements( element_coords=self.element_coords[indices], texts=self.texts[indices], + is_extracted_array=self.is_extracted_array[indices], sources=self.sources[indices], element_probs=self.element_probs[indices], element_class_ids=self.element_class_ids[indices], @@ -87,7 +91,7 @@ def slice(self, indices) -> LayoutElements: @classmethod def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: """concatenate a sequence of LayoutElements in order as one LayoutElements""" - coords, texts, probs, class_ids, sources = [], [], [], [], [] + coords, texts, probs, class_ids, sources, is_extracted_array = [], [], [], [], [], [] text_as_html, table_as_cells = [], [] class_id_reverse_map: dict[str, int] = {} for group in groups: @@ -95,6 +99,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: texts.append(group.texts) probs.append(group.element_probs) sources.append(group.sources) + is_extracted_array.append(group.is_extracted_array) text_as_html.append(group.text_as_html) table_as_cells.append(group.table_as_cells) @@ -116,6 +121,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: element_class_ids=np.concatenate(class_ids), element_class_id_map={v: k for k, v in class_id_reverse_map.items()}, sources=np.concatenate(sources), + is_extracted_array=np.concatenate(is_extracted_array), text_as_html=np.concatenate(text_as_html), table_as_cells=np.concatenate(table_as_cells), ) @@ -123,12 +129,22 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: def iter_elements(self): """iter elements as one LayoutElement per iteration; this returns a generator and has less memory impact than the as_list method""" - for (x1, y1, x2, y2), text, prob, class_id, source, text_as_html, table_as_cells in zip( + for ( + (x1, y1, x2, y2), + text, + prob, + class_id, + source, + is_extracted, + text_as_html, + table_as_cells, + ) in zip( self.element_coords, self.texts, self.element_probs, self.element_class_ids, self.sources, + self.is_extracted_array, self.text_as_html, self.table_as_cells, ): @@ -145,6 +161,7 @@ def iter_elements(self): ), prob=None if np.isnan(prob) else prob, source=source, + is_extracted=is_extracted, text_as_html=text_as_html, table_as_cells=table_as_cells, ) @@ -157,13 +174,21 @@ def from_list(cls, elements: list): coords = np.empty((len_ele, 4), dtype=float) # text and probs can be Nones so use lists first then convert into array to avoid them being # filled as nan - texts, text_as_html, table_as_cells, sources, class_probs = [], [], [], [], [] + texts, text_as_html, table_as_cells, sources, is_extracted_array, class_probs = ( + [], + [], + [], + [], + [], + [], + ) class_types = np.empty((len_ele,), dtype="object") for i, element in enumerate(elements): coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2] texts.append(element.text) sources.append(element.source) + is_extracted_array.append(element.is_extracted) text_as_html.append(element.text_as_html) table_as_cells.append(element.table_as_cells) class_probs.append(element.prob) @@ -179,6 +204,7 @@ def from_list(cls, elements: list): element_class_ids=class_ids, element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)), sources=np.array(sources), + is_extracted_array=np.array(is_extracted_array), text_as_html=np.array(text_as_html), table_as_cells=np.array(table_as_cells), ) @@ -201,6 +227,7 @@ def to_dict(self) -> dict: "type": self.type, "prob": self.prob, "source": self.source, + "is_extracted": self.is_extracted, } return out_dict @@ -211,7 +238,45 @@ def from_region(cls, region: TextRegion): type = region.type if hasattr(region, "type") else None prob = region.prob if hasattr(region, "prob") else None source = region.source if hasattr(region, "source") else None - return cls(text=text, source=source, type=type, prob=prob, bbox=region.bbox) + is_extracted = region.is_extracted if hasattr(region, "is_extracted") else None + return cls( + bbox=region.bbox, + text=text, + source=source, + is_extracted=is_extracted, + type=type, + prob=prob, + ) + + @classmethod + def from_coords( + cls, + x1: Union[int, float], + y1: Union[int, float], + x2: Union[int, float], + y2: Union[int, float], + text: Optional[str] = None, + source: Optional[Source] = None, + is_extracted: Optional[IsExtracted] = None, + type: Optional[str] = None, + prob: Optional[float] = None, + text_as_html: Optional[str] = None, + table_as_cells: Optional[str] = None, + **kwargs, + ) -> LayoutElement: + """Constructs a LayoutElement from coordinates.""" + bbox = Rectangle(x1, y1, x2, y2) + return cls( + text=text, + is_extracted=is_extracted, + type=type, + prob=prob, + source=source, + text_as_html=text_as_html, + table_as_cells=table_as_cells, + bbox=bbox, + **kwargs, + ) def separate(region_a: Rectangle, region_b: Rectangle): @@ -362,7 +427,7 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float = final_attrs: dict[str, Any] = { "element_class_id_map": elements.element_class_id_map, } - for attr in ("element_class_ids", "element_probs", "texts", "sources"): + for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"): if (original_attr := getattr(elements, attr)) is None: continue final_attrs[attr] = original_attr[sorted_by_area][mask][sorted_by_y1] @@ -438,7 +503,7 @@ def clean_layoutelements_for_class( final_coords = np.vstack([target_coords[mask], other_coords[other_mask]]) final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map} - for attr in ("element_class_ids", "element_probs", "texts", "sources"): + for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"): if (original_attr := getattr(elements, attr)) is None: continue final_attrs[attr] = np.concatenate(