Skip to content

Commit 77367f2

Browse files
support new dense vector quantization in 8.16
1 parent 0dd69f8 commit 77367f2

File tree

3 files changed

+120
-4
lines changed

3 files changed

+120
-4
lines changed

elasticsearch_dsl/field.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
         return float(data)
 
 
-class DenseVector(Float):
+class DenseVector(Field):
     name = "dense_vector"
+    _coerce = True
 
     def __init__(self, **kwargs: Any):
-        kwargs["multi"] = True
+        self._element_type = kwargs.get("element_type", "float")
+        if self._element_type in ["float", "byte"]:
+            kwargs["multi"] = True
         super().__init__(**kwargs)
 
+    def _deserialize(self, data: Any) -> Any:
+        if self._element_type == "float":
+            return float(data)
+        elif self._element_type == "byte":
+            return int(data)
+        return data
+
 
 class SparseVector(Field):
     name = "sparse_vector"

tests/test_integration/_async/test_document.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
 
 from datetime import datetime
 from ipaddress import ip_address
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union
 
 import pytest
 from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
@@ -37,6 +37,7 @@
     Binary,
     Boolean,
     Date,
+    DenseVector,
     Double,
     InnerDoc,
     Ip,
@@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
         "age": 45,
         "languages": ["es"],
     }
+
+
+@pytest.mark.asyncio
+async def test_float_dense_vector(
+    async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
+) -> None:
+    if es_version >= (8, 16):
+        pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
+
+    class Doc(AsyncDocument):
+        float_vector: List[float] = mapped_field(DenseVector())
+
+        class Index:
+            name = "vectors"
+
+    await Doc._index.delete(ignore_unavailable=True)
+    await Doc.init()
+
+    doc = Doc(float_vector=[1.0, 1.2, 2.3])
+    await doc.save(refresh=True)
+
+    docs = await Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+
+
+@pytest.mark.asyncio
+async def test_dense_vector(
+    async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
+) -> None:
+    if es_version < (8, 16):
+        pytest.skip("this test requires Elasticsearch 8.16 or newer")
+
+    class Doc(AsyncDocument):
+        float_vector: List[float] = mapped_field(DenseVector())
+        byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
+        bit_vector: str = mapped_field(DenseVector(element_type="bit"))
+
+        class Index:
+            name = "vectors"
+
+    await Doc._index.delete(ignore_unavailable=True)
+    await Doc.init()
+
+    doc = Doc(
+        float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
+    )
+    await doc.save(refresh=True)
+
+    docs = await Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+    assert docs[0].byte_vector == doc.byte_vector
+    assert docs[0].bit_vector == doc.bit_vector

tests/test_integration/_sync/test_document.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
 
 from datetime import datetime
 from ipaddress import ip_address
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union
 
 import pytest
 from elasticsearch import ConflictError, Elasticsearch, NotFoundError
@@ -35,6 +35,7 @@
     Binary,
     Boolean,
     Date,
+    DenseVector,
     Document,
     Double,
     InnerDoc,
@@ -789,3 +790,53 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
         "age": 45,
         "languages": ["es"],
     }
+
+
+@pytest.mark.sync
+def test_float_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
+    if es_version >= (8, 16):
+        pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
+
+    class Doc(Document):
+        float_vector: List[float] = mapped_field(DenseVector())
+
+        class Index:
+            name = "vectors"
+
+    Doc._index.delete(ignore_unavailable=True)
+    Doc.init()
+
+    doc = Doc(float_vector=[1.0, 1.2, 2.3])
+    doc.save(refresh=True)
+
+    docs = Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+
+
+@pytest.mark.sync
+def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
+    if es_version < (8, 16):
+        pytest.skip("this test requires Elasticsearch 8.16 or newer")
+
+    class Doc(Document):
+        float_vector: List[float] = mapped_field(DenseVector())
+        byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
+        bit_vector: str = mapped_field(DenseVector(element_type="bit"))
+
+        class Index:
+            name = "vectors"
+
+    Doc._index.delete(ignore_unavailable=True)
+    Doc.init()
+
+    doc = Doc(
+        float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
+    )
+    doc.save(refresh=True)
+
+    docs = Doc.search().execute()
+    assert len(docs) == 1
+    assert docs[0].float_vector == doc.float_vector
+    assert docs[0].byte_vector == doc.byte_vector
+    assert docs[0].bit_vector == doc.bit_vector

0 commit comments

Comments
 (0)