diff --git a/text_extensions_for_pandas/array/test_token_span.py b/text_extensions_for_pandas/array/test_token_span.py index 613ebe0a..27af62d8 100644 --- a/text_extensions_for_pandas/array/test_token_span.py +++ b/text_extensions_for_pandas/array/test_token_span.py @@ -366,6 +366,19 @@ def test_as_frame(self): ) self.assertEqual(len(df), len(arr)) + def test_multi_doc(self): + arr1 = self._make_spans() + + text2 = "Hello world." + tokens2 = SpanArray(text2, [0, 6], [5, 11]) + arr2 = TokenSpanArray(tokens2, [0, 0], [1, 2]) + + series = pd.concat([pd.Series(arr1), pd.Series(arr2)]) + self.assertFalse(series.array.is_single_document) + self.assertEqual(2, len(series.array.split_by_document())) + self._assertArrayEquals(arr1, series.array.split_by_document()[0]) + self._assertArrayEquals(arr2, series.array.split_by_document()[1]) + @pytest.mark.skipif(LooseVersion(pa.__version__) < LooseVersion("2.0.0"), reason="Nested dictionaries only supported in Arrow >= 2.0.0") diff --git a/text_extensions_for_pandas/array/token_span.py b/text_extensions_for_pandas/array/token_span.py index 062624ca..82e11668 100644 --- a/text_extensions_for_pandas/array/token_span.py +++ b/text_extensions_for_pandas/array/token_span.py @@ -983,20 +983,20 @@ def is_single_document(self) -> bool: # More than one tokenization and at least one span. Check whether # every span has the same text. - # Find the first text ID that is not NA - first_text_id = None - for b, t in zip(self._begins, self._text_ids): + # Find the first span that is not NA + first_target_text = None + for b, t in zip(self._begin_tokens, self.target_text): if b != Span.NULL_OFFSET_VALUE: - first_text_id = t + first_target_text = t break - if first_text_id is None: + if first_target_text is None: # Special case: All NAs --> Zero documents return True return not np.any(np.logical_and( # Row is not null... - np.not_equal(self._begins, Span.NULL_OFFSET_VALUE), + np.not_equal(self._begin_tokens, Span.NULL_OFFSET_VALUE), # ...and is over a different text than the first row's text ID - np.not_equal(self._text_ids, first_text_id))) + np.not_equal(self.target_text, first_target_text))) def split_by_document(self) -> List["SpanArray"]: """