Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ test = [
# type checker and type stubs for static type checking
typing = [
"mypy >= 1.15.0",
"typing-extensions >= 4.13.2",
"types-pytz >= 2025.2.0.20250326",
{include-group = "dep-typing-extensions"},
{include-group = "dep-project-dependencies"},
# tests are also type-checked
{include-group = "test"},
Expand All @@ -124,7 +124,11 @@ typing = [
{include-group = "benchkit"},
]
# generate documentation
docs = ["Sphinx >= 8.1.3"]
docs = [
"Sphinx >= 8.1.3",
{include-group = "dep-typing-extensions"},
{include-group = "dep-project-dependencies"},
]
# running the testkit backend
testkit = [
{include-group = "dep-freezegun"},
Expand All @@ -141,6 +145,7 @@ release = [
]

# single dependencies and other include-groups (not meant to be installed on their own; they only exist to avoid duplication)
dep-typing-extensions = ["typing-extensions >= 4.13.2"]
dep-freezegun = ["freezegun >= 1.5.1"]
dep-project-dependencies = [
"pytz",
Expand Down
84 changes: 46 additions & 38 deletions src/neo4j/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,18 @@
from enum import Enum as _Enum

from . import _typing as _t
from ._optional_deps import (
np as _np,
pa as _pa,
)


if _t.TYPE_CHECKING:
# Why the two branches? Because of Sphinx: importing the real modules under
# TYPE_CHECKING (instead of the optional-dependency stand-ins used at runtime)
# helps Sphinx to properly resolve the type hints.
import numpy as _np
import pyarrow as _pa
else:
from ._optional_deps import (
np as _np,
pa as _pa,
)


if False:
Expand All @@ -41,10 +49,6 @@
_swap_endian_unchecked_rust = None
_vec_rust = None

if _t.TYPE_CHECKING:
import numpy # type: ignore[import]
import pyarrow # type: ignore[import]


__all__ = [
"Vector",
Expand Down Expand Up @@ -84,14 +88,18 @@ class Vector:
Use an iterable of floats or an iterable of ints to construct the
vector from native Python values.
The ``dtype`` parameter is required.
See also: :meth:`.from_native`.
* ``bytes``, ``bytearray``: Use raw bytes to construct the vector.
The ``dtype`` parameter is required and ``byteorder`` is optional.
* ``numpy.ndarray``: Use a numpy array to construct the vector.
No further parameters are accepted.
See also: :meth:`.from_numpy`.
* ``pyarrow.Array``: Use a pyarrow array to construct the vector.
No further parameters are accepted.
See also: :meth:`.from_pyarrow`.
:param dtype: The type of the vector.
See :attr:`.dtype` for currently supported inner data types.
See :class:`.VectorDType` for currently supported inner data types.
See also :attr:`.dtype`.

This parameter is required if ``data`` is of type :class:`bytes`,
:class:`bytearray`, ``Iterable[float]``, or ``Iterable[int]``.
Expand Down Expand Up @@ -163,10 +171,10 @@ def __init__(
) -> None: ...

@_t.overload
def __init__(self, data: numpy.ndarray, /) -> None: ...
def __init__(self, data: _np.ndarray, /) -> None: ...

@_t.overload
def __init__(self, data: pyarrow.Array, /) -> None: ...
def __init__(self, data: _pa.Array, /) -> None: ...

def __init__(self, data, *args, **kwargs) -> None:
if isinstance(data, (bytes, bytearray)):
Expand Down Expand Up @@ -373,15 +381,15 @@ def to_native(self) -> list[object]:
return self._inner.to_native()

@classmethod
def from_numpy(cls, data: numpy.ndarray, /) -> _t.Self:
def from_numpy(cls, data: _np.ndarray, /) -> _t.Self:
"""
Create a Vector instance from a numpy array.

:param data: The numpy array to create the vector from.
The array must be one-dimensional and have a dtype that is
supported by Neo4j vectors: ``float64``, ``float32``,
``int64``, ``int32``, ``int16``, or ``int8``.
See also :attr:`.dtype`.
See also :class:`.VectorDType`.

:raises ValueError:
* If the dtype is not supported.
Expand All @@ -394,7 +402,7 @@ def from_numpy(cls, data: numpy.ndarray, /) -> _t.Self:
obj._set_numpy(data)
return obj

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
"""
Convert the vector to a numpy array.

Expand All @@ -407,7 +415,7 @@ def to_numpy(self) -> numpy.ndarray:
"""
return self._inner.to_numpy()

def _set_numpy(self, data: numpy.ndarray, /) -> None:
def _set_numpy(self, data: _np.ndarray, /) -> None:
if data.ndim != 1:
raise ValueError("Data must be one-dimensional")
type_: type[_InnerVector]
Expand All @@ -429,18 +437,17 @@ def _set_numpy(self, data: numpy.ndarray, /) -> None:
self._inner = type_.from_numpy(data)

@classmethod
def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
def from_pyarrow(cls, data: _pa.Array, /) -> _t.Self:
"""
Create a Vector instance from a pyarrow array.

:param data: The pyarrow array to create the vector from.
The array must have a type that is supported by Neo4j.
See also :attr:`.dtype`.

PyArrow stores data in little-endian byte order, whereas the vector keeps
its data in big-endian byte order. Therefore, the byte order needs to be
swapped. If ``neo4j-rust-ext`` or ``numpy`` is installed, it will be used
to speed up the byte swapping.

:param data: The pyarrow array to create the vector from.
The array must have a type that is supported by Neo4j.
See also :class:`.VectorDType`.
:raises ValueError:
* If the array's type is not supported.
* If the array contains null values.
Expand All @@ -452,7 +459,7 @@ def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
obj._set_pyarrow(data)
return obj

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
"""
Convert the vector to a pyarrow array.

Expand All @@ -462,7 +469,7 @@ def to_pyarrow(self) -> pyarrow.Array:
"""
return self._inner.to_pyarrow()

def _set_pyarrow(self, data: pyarrow.Array, /) -> None:
def _set_pyarrow(self, data: _pa.Array, /) -> None:
import pyarrow

type_: type[_InnerVector]
Expand Down Expand Up @@ -581,6 +588,7 @@ def _swap_endian(type_size: int, data: bytes, /) -> bytes:


def _swap_endian_unchecked_np(type_size: int, data: bytes, /) -> bytes:
dtype: _np.dtype
match type_size:
case 2:
dtype = _np.dtype("<i2")
Expand Down Expand Up @@ -727,18 +735,18 @@ def from_native(cls, data: _t.Iterable[object], /) -> _t.Self: ...
def to_native(self) -> list[object]: ...

@classmethod
def from_numpy(cls, data: numpy.ndarray, /) -> _t.Self:
def from_numpy(cls, data: _np.ndarray, /) -> _t.Self:
if data.dtype.byteorder == "<" or (
data.dtype.byteorder == "=" and _sys.byteorder == "little"
):
data = data.byteswap()
return cls(data.tobytes())

@_abc.abstractmethod
def to_numpy(self) -> numpy.ndarray: ...
def to_numpy(self) -> _np.ndarray: ...

@classmethod
def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
def from_pyarrow(cls, data: _pa.Array, /) -> _t.Self:
width = data.type.byte_width
assert cls.size == width
if _pa.compute.count(data, mode="only_null").as_py():
Expand All @@ -750,7 +758,7 @@ def from_pyarrow(cls, data: pyarrow.Array, /) -> _t.Self:
return cls(bytes(buffer), byteorder=_sys.byteorder)

@_abc.abstractmethod
def to_pyarrow(self) -> pyarrow.Array: ...
def to_pyarrow(self) -> _pa.Array: ...


class _InnerVectorFloat(_InnerVector, _abc.ABC):
Expand Down Expand Up @@ -822,12 +830,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">f8"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down Expand Up @@ -897,12 +905,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">f4"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down Expand Up @@ -997,12 +1005,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">i8"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down Expand Up @@ -1090,12 +1098,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">i4"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down Expand Up @@ -1183,12 +1191,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">i2"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down Expand Up @@ -1276,12 +1284,12 @@ def _to_native_py(self) -> list[object]:
else:
to_native = _to_native_py

def to_numpy(self) -> numpy.ndarray:
def to_numpy(self) -> _np.ndarray:
import numpy

return numpy.frombuffer(self.data, dtype=numpy.dtype(">i1"))

def to_pyarrow(self) -> pyarrow.Array:
def to_pyarrow(self) -> _pa.Array:
import pyarrow

buffer = pyarrow.py_buffer(self.data_le)
Expand Down