Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bindings/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@

---

# Changes in Version 1.9.0 (2025/XX/XX)

- Providing a schema now enforces strict type adherence for data.
If a result contains a field whose value does not match the schema's type for that field, a TypeError will be raised.
Note that ``NaN`` is a valid type for all fields.
To suppress these errors and instead silently convert such mismatches to ``NaN``, pass the ``allow_invalid=True`` argument to your ``pymongoarrow`` API call.
For example, a result with a field of type ``int`` but with a string value will now raise a TypeError,
unless ``allow_invalid=True`` is passed, in which case the result's field will have a value of ``NaN``.

# Changes in Version 1.8.0 (2025/05/12)

- Add support for PyArrow 20.0.
Expand Down
70 changes: 54 additions & 16 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)


def find_arrow_all(collection, query, *, schema=None, **kwargs):
def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`pyarrow.Table` instance.

Expand All @@ -83,14 +83,18 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.

:Returns:
An instance of class:`pyarrow.Table`.
"""
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)

for opt in ("cursor_type",):
if kwargs.pop(opt, None):
Expand All @@ -110,7 +114,7 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
return context.finish()


def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`pyarrow.Table` instance.

Expand All @@ -121,14 +125,18 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`pyarrow.Table`.
"""
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
context = PyMongoArrowContext(
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
)

if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
msg = (
Expand Down Expand Up @@ -165,7 +173,7 @@ def _arrow_to_pandas(arrow_table):
return arrow_table.to_pandas(split_blocks=True, self_destruct=True)


def find_pandas_all(collection, query, *, schema=None, **kwargs):
def find_pandas_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`pandas.DataFrame` instance.

Expand All @@ -176,17 +184,21 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.

:Returns:
An instance of class:`pandas.DataFrame`.
"""
return _arrow_to_pandas(find_arrow_all(collection, query, schema=schema, **kwargs))
return _arrow_to_pandas(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
)


def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_pandas_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`pandas.DataFrame` instance.

Expand All @@ -197,14 +209,20 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`pandas.DataFrame`.
"""
return _arrow_to_pandas(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
return _arrow_to_pandas(
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
)
)


def _arrow_to_numpy(arrow_table, schema=None):
Expand All @@ -227,7 +245,7 @@ def _arrow_to_numpy(arrow_table, schema=None):
return container


def find_numpy_all(collection, query, *, schema=None, **kwargs):
def find_numpy_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`dict` instance whose keys are field names and values are
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
Expand All @@ -239,6 +257,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -255,10 +275,13 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
:Returns:
An instance of :class:`dict`.
"""
return _arrow_to_numpy(find_arrow_all(collection, query, schema=schema, **kwargs), schema)
return _arrow_to_numpy(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs),
schema,
)


def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_numpy_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`dict` instance whose keys are field names and values are
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
Expand All @@ -270,6 +293,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.
Expand All @@ -287,7 +312,10 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
An instance of :class:`dict`.
"""
return _arrow_to_numpy(
aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs), schema
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
),
schema,
)


Expand Down Expand Up @@ -326,7 +354,7 @@ def _arrow_to_polars(arrow_table: pa.Table):
return pl.from_arrow(arrow_table_without_extensions)


def find_polars_all(collection, query, *, schema=None, **kwargs):
def find_polars_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of a find query as a
:class:`polars.DataFrame` instance.

Expand All @@ -337,6 +365,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``find`` operation.
Expand All @@ -346,10 +376,12 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):

.. versionadded:: 1.3
"""
return _arrow_to_polars(find_arrow_all(collection, query, schema=schema, **kwargs))
return _arrow_to_polars(
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
)


def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
def aggregate_polars_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
"""Method that returns the results of an aggregation pipeline as a
:class:`polars.DataFrame` instance.

Expand All @@ -360,14 +392,20 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
If the schema is not given, it will be inferred using the data in the
result set.
- `allow_invalid` (optional): If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.

Additional keyword-arguments passed to this method will be passed
directly to the underlying ``aggregate`` operation.

:Returns:
An instance of class:`polars.DataFrame`.
"""
return _arrow_to_polars(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
return _arrow_to_polars(
aggregate_arrow_all(
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
)
)


def _transform_bwe(bwe, offset):
Expand Down
8 changes: 6 additions & 2 deletions bindings/python/pymongoarrow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
class PyMongoArrowContext:
"""A context for converting BSON-formatted data to an Arrow Table."""

def __init__(self, schema, codec_options=None):
def __init__(self, schema, codec_options=None, allow_invalid=False):
"""Initialize the context.

:Parameters:
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
- `builder_map`: Mapping of utf-8-encoded field names to
:class:`~pymongoarrow.builders._BuilderBase` instances.
- `allow_invalid`: If set to ``True``,
results will have all fields that do not conform to the schema silently converted to NaN.
"""
self.schema = schema
if self.schema is None and codec_options is not None:
Expand All @@ -40,7 +42,9 @@ def __init__(self, schema, codec_options=None):
# Delayed import to prevent import errors for unbuilt library.
from pymongoarrow.lib import BuilderManager

self.manager = BuilderManager(schema_map, self.schema is not None, self.tzinfo)
self.manager = BuilderManager(
schema_map, self.schema is not None, self.tzinfo, allow_invalid=allow_invalid
)
self.schema_map = schema_map

def process_bson_stream(self, stream):
Expand Down
Loading
Loading