diff --git a/learning_resources_search/api.py b/learning_resources_search/api.py index 8c6ee6fb01..fb7f96e0c9 100644 --- a/learning_resources_search/api.py +++ b/learning_resources_search/api.py @@ -922,19 +922,17 @@ def _qdrant_similar_results(doc, num_resources): list of dict: list of serialized resources """ - from learning_resources_search.indexing_api import qdrant_client + from learning_resources_search.indexing_api import qdrant_client, vector_point_id client = qdrant_client() return [ - hit.metadata - for hit in client.query( + hit.payload + for hit in client.query_points( collection_name=f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources", - query_text=( - f'{doc.get("title")} {doc.get("description")} ' - f'{doc.get("full_description")} {doc.get("content")}' - ), + query=vector_point_id(doc["readable_id"]), limit=num_resources, - ) + using=settings.QDRANT_SEARCH_VECTOR_NAME, + ).points ] diff --git a/learning_resources_search/indexing_api.py b/learning_resources_search/indexing_api.py index 9c606bfe61..ac91446815 100644 --- a/learning_resources_search/indexing_api.py +++ b/learning_resources_search/indexing_api.py @@ -4,6 +4,7 @@ import json import logging +import uuid from math import ceil from django.conf import settings @@ -143,6 +144,10 @@ def create_qdrand_collections(force_recreate): ) +def vector_point_id(readable_id): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, readable_id)) + + def embed_learning_resources(ids, resource_type): # update embeddings client = qdrant_client() @@ -168,7 +173,13 @@ def embed_learning_resources(ids, resource_type): f'{doc.get("full_description")} {doc.get("content")}' ) metadata.append(doc) - ids.append(doc["id"]) + if resource_type != CONTENT_FILE_TYPE: + vector_point_key = doc["readable_id"] + else: + vector_point_key = ( + f"{doc['key']}.{doc['run_readable_id']}.{doc['resource_readable_id']}" + ) + ids.append(vector_point_id(vector_point_key)) client.add( collection_name=collection_name, ids=ids, diff --git a/learning_resources_search/indexing_api_test.py b/learning_resources_search/indexing_api_test.py index cfae97d13e..f10126e6d9 100644 --- a/learning_resources_search/indexing_api_test.py +++ b/learning_resources_search/indexing_api_test.py @@ -13,6 +13,7 @@ from learning_resources.factories import ( ContentFileFactory, CourseFactory, + LearningResourceFactory, LearningResourceRunFactory, ) from learning_resources.models import ContentFile @@ -37,6 +38,7 @@ deindex_percolators, deindex_run_content_files, delete_orphaned_indexes, + embed_learning_resources, get_reindexing_alias_name, index_content_files, index_course_content_files, @@ -45,8 +47,10 @@ index_run_content_files, switch_indices, update_document_with_partial, + vector_point_id, ) from learning_resources_search.models import PercolateQuery +from learning_resources_search.serializers import serialize_bulk_content_files from learning_resources_search.utils import remove_child_queries from main.utils import chunks @@ -896,3 +900,31 @@ def test_clear_featured_rank(mocked_es, mocker, clear_all_greater_than): "query": query, }, ) + + +@pytest.mark.parametrize("content_type", ["learning_resource", "content_file"]) +def test_vector_point_id_used_for_embed(mocker, content_type): + # test the vector ids we generate for embedding resources and files + if content_type == "learning_resource": + resources = LearningResourceFactory.create_batch(5) + else: + resources = ContentFileFactory.create_batch(5) + mock_qdrant = mocker.patch("qdrant_client.QdrantClient") + mock_qdrant.query.return_value = [] + mocker.patch( + "learning_resources_search.indexing_api.qdrant_client", + return_value=mock_qdrant, + ) + + embed_learning_resources([resource.id for resource in resources], content_type) + + if content_type == "learning_resource": + point_ids = [vector_point_id(resource.readable_id) for resource in resources] + else: + point_ids = [ + vector_point_id( + f"{resource['key']}.{resource['run_readable_id']}.{resource['resource_readable_id']}" + ) + for resource in serialize_bulk_content_files([r.id for r in resources]) + ] + assert sorted(mock_qdrant.add.mock_calls[0].kwargs["ids"]) == sorted(point_ids) diff --git a/main/settings.py b/main/settings.py index 8b95c15ec9..be81dee2db 100644 --- a/main/settings.py +++ b/main/settings.py @@ -794,6 +794,9 @@ def get_all_config_keys(): name="QDRANT_COLLECTION_NAME", default="resource_embeddings" ) +QDRANT_SEARCH_VECTOR_NAME = get_string( + name="QDRANT_SEARCH_VECTOR_NAME", default="fast-bge-small-en" +) QDRANT_DENSE_MODEL = get_string( name="QDRANT_DENSE_MODEL", default="sentence-transformers/all-MiniLM-L6-v2"