Skip to content

Consistent qdrant point ids #1839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions learning_resources_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,19 +922,17 @@ def _qdrant_similar_results(doc, num_resources):
list of dict:
list of serialized resources
"""
from learning_resources_search.indexing_api import qdrant_client
from learning_resources_search.indexing_api import qdrant_client, vector_point_id

client = qdrant_client()
return [
hit.metadata
for hit in client.query(
hit.payload
for hit in client.query_points(
collection_name=f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources",
query_text=(
f'{doc.get("title")} {doc.get("description")} '
f'{doc.get("full_description")} {doc.get("content")}'
),
query=vector_point_id(doc["readable_id"]),
limit=num_resources,
)
using=settings.QDRANT_SEARCH_VECTOR_NAME,
).points
]


Expand Down
13 changes: 12 additions & 1 deletion learning_resources_search/indexing_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import logging
import uuid
from math import ceil

from django.conf import settings
Expand Down Expand Up @@ -143,6 +144,10 @@ def create_qdrand_collections(force_recreate):
)


def vector_point_id(readable_id):
    """
    Generate the deterministic point id used for a vector

    Args:
        readable_id (str): key identifying the item being embedded

    Returns:
        str:
            string form of the UUIDv5 of readable_id in the DNS namespace;
            the same key always yields the same id
    """
    point_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, readable_id)
    return str(point_uuid)


def embed_learning_resources(ids, resource_type):
# update embeddings
client = qdrant_client()
Expand All @@ -168,7 +173,13 @@ def embed_learning_resources(ids, resource_type):
f'{doc.get("full_description")} {doc.get("content")}'
)
metadata.append(doc)
ids.append(doc["id"])
if resource_type != CONTENT_FILE_TYPE:
vector_point_key = doc["readable_id"]
else:
vector_point_key = (
f"{doc['key']}.{doc['run_readable_id']}.{doc['resource_readable_id']}"
)
ids.append(vector_point_id(vector_point_key))
client.add(
collection_name=collection_name,
ids=ids,
Expand Down
32 changes: 32 additions & 0 deletions learning_resources_search/indexing_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from learning_resources.factories import (
ContentFileFactory,
CourseFactory,
LearningResourceFactory,
LearningResourceRunFactory,
)
from learning_resources.models import ContentFile
Expand All @@ -37,6 +38,7 @@
deindex_percolators,
deindex_run_content_files,
delete_orphaned_indexes,
embed_learning_resources,
get_reindexing_alias_name,
index_content_files,
index_course_content_files,
Expand All @@ -45,8 +47,10 @@
index_run_content_files,
switch_indices,
update_document_with_partial,
vector_point_id,
)
from learning_resources_search.models import PercolateQuery
from learning_resources_search.serializers import serialize_bulk_content_files
from learning_resources_search.utils import remove_child_queries
from main.utils import chunks

Expand Down Expand Up @@ -896,3 +900,31 @@ def test_clear_featured_rank(mocked_es, mocker, clear_all_greater_than):
"query": query,
},
)


@pytest.mark.parametrize("content_type", ["learning_resource", "content_file"])
def test_vector_point_id_used_for_embed(mocker, content_type):
    """Ids passed to qdrant when embedding should be the generated vector point ids"""
    embedding_resources = content_type == "learning_resource"
    factory = LearningResourceFactory if embedding_resources else ContentFileFactory
    resources = factory.create_batch(5)

    # swap the real qdrant client for a mock so the `add` call can be inspected
    mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
    mock_qdrant.query.return_value = []
    mocker.patch(
        "learning_resources_search.indexing_api.qdrant_client",
        return_value=mock_qdrant,
    )

    embed_learning_resources([resource.id for resource in resources], content_type)

    if embedding_resources:
        # learning resources are keyed by their readable_id alone
        point_ids = [vector_point_id(resource.readable_id) for resource in resources]
    else:
        # content files are keyed by file key + run + parent resource
        serialized_files = serialize_bulk_content_files([r.id for r in resources])
        point_ids = []
        for resource in serialized_files:
            file_key = f"{resource['key']}.{resource['run_readable_id']}.{resource['resource_readable_id']}"
            point_ids.append(vector_point_id(file_key))

    added_ids = mock_qdrant.add.mock_calls[0].kwargs["ids"]
    assert sorted(added_ids) == sorted(point_ids)
3 changes: 3 additions & 0 deletions main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,9 @@ def get_all_config_keys():
name="QDRANT_COLLECTION_NAME", default="resource_embeddings"
)

# Vector name passed as `using=` to the qdrant `query_points` call that looks
# up similar resources.
# NOTE(review): presumably this must match the vector name the collection was
# populated under, which may depend on QDRANT_DENSE_MODEL - TODO confirm the
# two defaults are consistent.
QDRANT_SEARCH_VECTOR_NAME = get_string(
name="QDRANT_SEARCH_VECTOR_NAME", default="fast-bge-small-en"
)

QDRANT_DENSE_MODEL = get_string(
name="QDRANT_DENSE_MODEL", default="sentence-transformers/all-MiniLM-L6-v2"
Expand Down