Skip to content

Commit e135ab8

Browse files
miguelgrinberggithub-actions[bot]
authored andcommitted
support new dense vector quantization in 8.16 (#1948)
* support new dense vector quantization in 8.16 * use 8.16 in CI builds (cherry picked from commit 5de355e)
1 parent 90a9e59 commit e135ab8

File tree

4 files changed

+123
-5
lines changed

4 files changed

+123
-5
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ jobs:
8383
"3.12",
8484
"3.13",
8585
]
86-
es-version: [8.0.0, 8.15.0]
86+
es-version: [8.0.0, 8.16.0]
8787

8888
steps:
8989
- name: Remove irrelevant software to free up disk space

elasticsearch_dsl/field.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
389389
return float(data)
390390

391391

392-
class DenseVector(Float):
392+
class DenseVector(Field):
393393
name = "dense_vector"
394+
_coerce = True
394395

395396
def __init__(self, **kwargs: Any):
396-
kwargs["multi"] = True
397+
self._element_type = kwargs.get("element_type", "float")
398+
if self._element_type in ["float", "byte"]:
399+
kwargs["multi"] = True
397400
super().__init__(**kwargs)
398401

402+
def _deserialize(self, data: Any) -> Any:
403+
if self._element_type == "float":
404+
return float(data)
405+
elif self._element_type == "byte":
406+
return int(data)
407+
return data
408+
399409

400410
class SparseVector(Field):
401411
name = "sparse_vector"

tests/test_integration/_async/test_document.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from datetime import datetime
2525
from ipaddress import ip_address
26-
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
26+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union
2727

2828
import pytest
2929
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
@@ -37,6 +37,7 @@
3737
Binary,
3838
Boolean,
3939
Date,
40+
DenseVector,
4041
Double,
4142
InnerDoc,
4243
Ip,
@@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
795796
"age": 45,
796797
"languages": ["es"],
797798
}
799+
800+
801+
@pytest.mark.asyncio
802+
async def test_legacy_dense_vector(
803+
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
804+
) -> None:
805+
if es_version >= (8, 16):
806+
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
807+
808+
class Doc(AsyncDocument):
809+
float_vector: List[float] = mapped_field(DenseVector(dims=3))
810+
811+
class Index:
812+
name = "vectors"
813+
814+
await Doc._index.delete(ignore_unavailable=True)
815+
await Doc.init()
816+
817+
doc = Doc(float_vector=[1.0, 1.2, 2.3])
818+
await doc.save(refresh=True)
819+
820+
docs = await Doc.search().execute()
821+
assert len(docs) == 1
822+
assert docs[0].float_vector == doc.float_vector
823+
824+
825+
@pytest.mark.asyncio
826+
async def test_dense_vector(
827+
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
828+
) -> None:
829+
if es_version < (8, 16):
830+
pytest.skip("this test requires Elasticsearch 8.16 or newer")
831+
832+
class Doc(AsyncDocument):
833+
float_vector: List[float] = mapped_field(DenseVector())
834+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
835+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
836+
837+
class Index:
838+
name = "vectors"
839+
840+
await Doc._index.delete(ignore_unavailable=True)
841+
await Doc.init()
842+
843+
doc = Doc(
844+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
845+
)
846+
await doc.save(refresh=True)
847+
848+
docs = await Doc.search().execute()
849+
assert len(docs) == 1
850+
assert docs[0].float_vector == doc.float_vector
851+
assert docs[0].byte_vector == doc.byte_vector
852+
assert docs[0].bit_vector == doc.bit_vector

tests/test_integration/_sync/test_document.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from datetime import datetime
2525
from ipaddress import ip_address
26-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
26+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union
2727

2828
import pytest
2929
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
@@ -35,6 +35,7 @@
3535
Binary,
3636
Boolean,
3737
Date,
38+
DenseVector,
3839
Document,
3940
Double,
4041
InnerDoc,
@@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
789790
"age": 45,
790791
"languages": ["es"],
791792
}
793+
794+
795+
@pytest.mark.sync
796+
def test_legacy_dense_vector(
797+
client: Elasticsearch, es_version: Tuple[int, ...]
798+
) -> None:
799+
if es_version >= (8, 16):
800+
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")
801+
802+
class Doc(Document):
803+
float_vector: List[float] = mapped_field(DenseVector(dims=3))
804+
805+
class Index:
806+
name = "vectors"
807+
808+
Doc._index.delete(ignore_unavailable=True)
809+
Doc.init()
810+
811+
doc = Doc(float_vector=[1.0, 1.2, 2.3])
812+
doc.save(refresh=True)
813+
814+
docs = Doc.search().execute()
815+
assert len(docs) == 1
816+
assert docs[0].float_vector == doc.float_vector
817+
818+
819+
@pytest.mark.sync
820+
def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
821+
if es_version < (8, 16):
822+
pytest.skip("this test requires Elasticsearch 8.16 or newer")
823+
824+
class Doc(Document):
825+
float_vector: List[float] = mapped_field(DenseVector())
826+
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
827+
bit_vector: str = mapped_field(DenseVector(element_type="bit"))
828+
829+
class Index:
830+
name = "vectors"
831+
832+
Doc._index.delete(ignore_unavailable=True)
833+
Doc.init()
834+
835+
doc = Doc(
836+
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
837+
)
838+
doc.save(refresh=True)
839+
840+
docs = Doc.search().execute()
841+
assert len(docs) == 1
842+
assert docs[0].float_vector == doc.float_vector
843+
assert docs[0].byte_vector == doc.byte_vector
844+
assert docs[0].bit_vector == doc.bit_vector

0 commit comments

Comments
 (0)