
Commit 302bf6c

Contentfile chunk embeddings (#1905)
* updating deps and adding method for getting token count
* working contentfile chunk embeds
* storing chunks and update initial resource record with embeddings from contentfile chunk
* adding management command flag to generate embeds by id
* fixing test
* ensuring we stay under token size
* removing full content from points
* moving splitter to separate function
* adding test for text splitter
* adding more tests
* fixing test
* changing chunk key name
* fix test setting
* fixing test
1 parent 78f745f commit 302bf6c

File tree

9 files changed: +884 −324 lines


poetry.lock

Lines changed: 592 additions & 279 deletions
Generated lock file; diff not rendered.

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -84,6 +84,8 @@ qdrant-client = {extras = ["fastembed"], version = "^1.12.0"}
 onnxruntime = "1.20.1"
 openai = "^1.55.3"
 litellm = "^1.53.5"
+langchain = "^0.3.11"
+tiktoken = "^0.8.0"


 [tool.poetry.group.dev.dependencies]
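
The two new dependencies back the chunking work described in the commit message: langchain supplies a text splitter and tiktoken supplies token counting so chunks stay under the embedding model's token limit. A minimal sketch of how the two are commonly combined (the helper name, chunk size, and overlap below are illustrative, not taken from this repo):

import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Illustrative helper: split text into chunks that stay under a token budget,
# using a tiktoken-aware splitter. Chunk size and overlap are example values.
def split_by_tokens(text: str, max_tokens: int = 512, overlap: int = 50) -> list[str]:
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="cl100k_base",
        chunk_size=max_tokens,
        chunk_overlap=overlap,
    )
    return splitter.split_text(text)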

vector_search/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ def encode_batch(self, texts: list[str]) -> list[list[float]]:
 @pytest.fixture(autouse=True)
 def _use_dummy_encoder(settings):
     settings.QDRANT_ENCODER = "vector_search.conftest.DummyEmbedEncoder"
+    settings.QDRANT_DENSE_MODEL = None


 @pytest.fixture(autouse=True)

vector_search/encoders/litellm.py

Lines changed: 12 additions & 0 deletions
@@ -1,15 +1,27 @@
+import logging
+
+import tiktoken
 from litellm import embedding

 from vector_search.encoders.base import BaseEncoder

+log = logging.getLogger()
+

 class LiteLLMEncoder(BaseEncoder):
     """
     LiteLLM encoder
     """

+    token_encoding_name = "cl100k_base"  # noqa: S105
+
     def __init__(self, model_name="text-embedding-3-small"):
         self.model_name = model_name
+        try:
+            self.token_encoding_name = tiktoken.encoding_name_for_model(model_name)
+        except KeyError:
+            msg = f"Model {model_name} not found in tiktoken. defaulting to cl100k_base"
+            log.warning(msg)

     def encode_batch(self, texts: list[str]) -> list[list[float]]:
         return [
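
The encoder now resolves a tiktoken encoding per embedding model and falls back to cl100k_base for models tiktoken does not recognize. The commit message also mentions a method for getting token counts; a minimal sketch of that pattern (the helper below is illustrative, not the repo's actual method):

import tiktoken

# Illustrative token-count helper using the same fallback pattern as the diff above.
def token_count(text: str, model_name: str = "text-embedding-3-small") -> int:
    try:
        encoding_name = tiktoken.encoding_name_for_model(model_name)
    except KeyError:
        encoding_name = "cl100k_base"  # unknown model: fall back to a safe default
    return len(tiktoken.get_encoding(encoding_name).encode(text))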

vector_search/management/commands/generate_embeddings.py

Lines changed: 20 additions & 5 deletions
@@ -4,7 +4,7 @@

 from learning_resources_search.constants import LEARNING_RESOURCE_TYPES
 from main.utils import clear_search_cache, now_in_utc
-from vector_search.tasks import start_embed_resources
+from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources
 from vector_search.utils import (
     create_qdrand_collections,
 )
@@ -30,6 +30,12 @@ def add_arguments(self, parser):
             help="Embed all resource types (including content files)",
         )

+        parser.add_argument(
+            "--resource-ids",
+            dest="resource-ids",
+            help="Embed a specific set of resources (overrides the --all flag)",
+        )
+
         parser.add_argument(
             "--skip-contentfiles",
             dest="skip_content_files",
@@ -49,7 +55,7 @@ def add_arguments(self, parser):
     def handle(self, *args, **options):  # noqa: ARG002
         """Embed all LEARNING_RESOURCE_TYPES"""

-        if options["all"]:
+        if options["all"] or options["resource-ids"]:
             indexes_to_update = list(LEARNING_RESOURCE_TYPES)
         else:
             indexes_to_update = list(
@@ -66,9 +72,18 @@ def handle(self, *args, **options):  # noqa: ARG002
             return
         if options["recreate_collections"]:
            create_qdrand_collections(force_recreate=True)
-        task = start_embed_resources.delay(
-            indexes_to_update, skip_content_files=options["skip_content_files"]
-        )
+        if options["resource-ids"]:
+            task = embed_learning_resources_by_id.delay(
+                [
+                    int(resource_id)
+                    for resource_id in options["resource-ids"].split(",")
+                ],
+                skip_content_files=options["skip_content_files"],
+            )
+        else:
+            task = start_embed_resources.delay(
+                indexes_to_update, skip_content_files=options["skip_content_files"]
+            )
         self.stdout.write(
             f"Started celery task {task} to index content for the following"
             f" Types to embed: {indexes_to_update}"

vector_search/tasks.py

Lines changed: 81 additions & 10 deletions
@@ -16,6 +16,7 @@
     CONTENT_FILE_TYPE,
     COURSE_TYPE,
     LEARNING_PATH_TYPE,
+    LEARNING_RESOURCE_TYPES,
     PODCAST_EPISODE_TYPE,
     PODCAST_TYPE,
     PROGRAM_TYPE,
@@ -67,7 +68,7 @@ def generate_embeddings(ids, resource_type):
 @app.task(bind=True)
 def start_embed_resources(self, indexes, skip_content_files):
     """
-    Celery task to embed learning resources
+    Celery task to embed all learning resources for given indexes

     Args:
         indexes (list of str): resource types to embed
@@ -152,6 +153,73 @@ def start_embed_resources(self, indexes, skip_content_files):
     return self.replace(celery.chain(*index_tasks))


+@app.task(bind=True)
+def embed_learning_resources_by_id(self, ids, skip_content_files):
+    """
+    Celery task to embed specific resources
+
+    Args:
+        ids (list of int): list of resource ids to embed
+        skip_content_files (bool): whether to skip embedding content files
+    """
+    index_tasks = []
+    if not all([settings.QDRANT_HOST, settings.QDRANT_BASE_COLLECTION_NAME]):
+        log.warning(
+            "skipping. start_embed_resources called without setting "
+            "QDRANT_HOST and QDRANT_BASE_COLLECTION_NAME"
+        )
+        return None
+    resources = LearningResource.objects.filter(
+        id__in=ids,
+        published=True,
+    )
+    try:
+        for resource_type in LEARNING_RESOURCE_TYPES:
+            resources = resources.filter(resource_type=resource_type)
+
+            [
+                index_tasks.append(
+                    generate_embeddings.si(
+                        chunk_ids,
+                        resource_type,
+                    )
+                )
+                for chunk_ids in chunks(
+                    resources.order_by("id").values_list("id", flat=True),
+                    chunk_size=settings.OPENSEARCH_INDEXING_CHUNK_SIZE,
+                )
+            ]
+            if not skip_content_files and resource_type == COURSE_TYPE:
+                for course in resources.filter(
+                    etl_source__in=RESOURCE_FILE_ETL_SOURCES
+                ).order_by("id"):
+                    index_tasks = index_tasks + [
+                        generate_embeddings.si(
+                            content_ids,
+                            CONTENT_FILE_TYPE,
+                        )
+                        for content_ids in chunks(
+                            ContentFile.objects.filter(
+                                run__learning_resource_id=course.id,
+                                published=True,
+                                run__published=True,
+                            )
+                            .order_by("id")
+                            .values_list("id", flat=True),
+                            chunk_size=settings.OPENSEARCH_DOCUMENT_INDEXING_CHUNK_SIZE,
+                        )
+                    ]
+    except:  # noqa: E722
+        error = "start_embed_resources threw an error"
+        log.exception(error)
+        return error
+
+    # Use self.replace so that code waiting on this task will also wait on the embedding
+    # and finish tasks
+
+    return self.replace(celery.chain(*index_tasks))
+
+
 @app.task(bind=True)
 def embed_new_learning_resources(self):
     """
@@ -165,13 +233,16 @@ def embed_new_learning_resources(self):
         created_on__gt=since,
     ).exclude(resource_type=CONTENT_FILE_TYPE)
     filtered_resources = filter_existing_qdrant_points(new_learning_resources)
-    embed_tasks = celery.group(
-        [
-            generate_embeddings.si(ids, COURSE_TYPE)
-            for ids in chunks(
-                filtered_resources.order_by("id").values_list("id", flat=True),
-                chunk_size=settings.OPENSEARCH_INDEXING_CHUNK_SIZE,
-            )
-        ]
-    )
+    for resource_type in LEARNING_RESOURCE_TYPES:
+        embed_tasks = celery.group(
+            [
+                generate_embeddings.si(ids, resource_type)
+                for ids in chunks(
+                    filtered_resources.filter(resource_type=resource_type).values_list(
+                        "id", flat=True
+                    ),
+                    chunk_size=settings.OPENSEARCH_INDEXING_CHUNK_SIZE,
+                )
+            ]
+        )
     return self.replace(embed_tasks)
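
The new task can also be queued directly, and for courses from resource-file ETL sources it chains content-file embedding tasks after the resource embeddings. A minimal sketch of queuing it from a Django shell (the ids are placeholders; Celery workers and the QDRANT_HOST / QDRANT_BASE_COLLECTION_NAME settings are assumed to be configured):

from vector_search.tasks import embed_learning_resources_by_id

# Queue embeddings for two resources by primary key; ids here are placeholders.
result = embed_learning_resources_by_id.delay([101, 102], skip_content_files=False)
print(result.id)  # Celery task id; the chained embedding subtasks run asynchronously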

vector_search/tasks_test.py

Lines changed: 42 additions & 2 deletions
@@ -3,19 +3,24 @@
 import pytest
 from django.conf import settings

-from learning_resources.etl.constants import ETLSource
+from learning_resources.etl.constants import RESOURCE_FILE_ETL_SOURCES, ETLSource
 from learning_resources.factories import (
     ContentFileFactory,
     CourseFactory,
     LearningResourceFactory,
+    LearningResourceRunFactory,
     ProgramFactory,
 )
 from learning_resources.models import LearningResource
 from learning_resources_search.constants import (
     COURSE_TYPE,
 )
 from main.utils import now_in_utc
-from vector_search.tasks import embed_new_learning_resources, start_embed_resources
+from vector_search.tasks import (
+    embed_learning_resources_by_id,
+    embed_new_learning_resources,
+    start_embed_resources,
+)

 pytestmark = pytest.mark.django_db

@@ -138,3 +143,38 @@ def test_embed_new_learning_resources(mocker, mocked_celery):

     embedded_ids = generate_embeddings_mock.si.mock_calls[0].args[0]
     assert sorted(daily_resource_ids) == sorted(embedded_ids)
+
+
+def test_embed_learning_resources_by_id(mocker, mocked_celery):
+    """
+    embed_learning_resources_by_id should generate embeddings for resources
+    based on the ids passed as well as associated contentfiles
+    """
+    mocker.patch("vector_search.tasks.load_course_blocklist", return_value=[])
+
+    resources = LearningResourceFactory.create_batch(
+        4,
+        resource_type=COURSE_TYPE,
+        etl_source=RESOURCE_FILE_ETL_SOURCES[0],
+        published=True,
+    )
+
+    resource_ids = [resource.id for resource in resources]
+
+    generate_embeddings_mock = mocker.patch(
+        "vector_search.tasks.generate_embeddings", autospec=True
+    )
+    content_ids = []
+    for resource in resources:
+        cf = ContentFileFactory.create(
+            run=LearningResourceRunFactory.create(learning_resource=resource)
+        )
+        content_ids.append(cf.id)
+
+    with pytest.raises(mocked_celery.replace_exception_class):
+        embed_learning_resources_by_id.delay(resource_ids, skip_content_files=False)
+    for mock_call in generate_embeddings_mock.si.mock_calls[1:]:
+        assert mock_call.args[0][0] in content_ids
+        assert mock_call.args[1] == "content_file"
+    embedded_resource_ids = generate_embeddings_mock.si.mock_calls[0].args[0]
+    assert sorted(resource_ids) == sorted(embedded_resource_ids)
