
Commit 5f58f93

Semantic Chunking of Content Files (#2005)
* adding deps for semantic chunker
* conforming to langchain embedding interface and adding ability to toggle semantic chunker
* fixing tests
* combine recursive and semantic chunker to stay within chunk size
* fixing tests
* updating defaults
* adding test and fixes
* doc updates
1 parent c5c98a8 commit 5f58f93

File tree: 9 files changed, +246 −54 lines changed


main/settings.py

Lines changed: 28 additions & 2 deletions
```diff
@@ -849,9 +849,35 @@ def get_all_config_keys():
 AI_MAX_BUDGET = get_float(name="AI_MAX_BUDGET", default=0.05)
 AI_ANON_LIMIT_MULTIPLIER = get_float(name="AI_ANON_LIMIT_MULTIPLIER", default=10.0)
 CONTENT_FILE_EMBEDDING_CHUNK_SIZE_OVERRIDE = get_int(
-    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=None
+    name="CONTENT_FILE_EMBEDDING_CHUNK_SIZE", default=512
 )
 CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = get_int(
     name="CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP",
-    default=200,  # default that the tokenizer uses
+    default=0,  # default that the tokenizer uses
 )
+CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = get_bool(
+    name="CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED", default=False
+)
+
+SEMANTIC_CHUNKING_CONFIG = {
+    "buffer_size": get_int(
+        # Number of sentences to combine.
+        name="SEMANTIC_CHUNKING_BUFFER_SIZE",
+        default=1,
+    ),
+    "breakpoint_threshold_type": get_string(
+        # 'percentile', 'standard_deviation', 'interquartile', or 'gradient'
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_TYPE",
+        default="percentile",
+    ),
+    "breakpoint_threshold_amount": get_float(
+        # value we use for breakpoint_threshold_type to filter outliers
+        name="SEMANTIC_CHUNKING_BREAKPOINT_THRESHOLD_AMOUNT",
+        default=None,
+    ),
+    "number_of_chunks": get_int(
+        # number of chunks to consider for merging
+        name="SEMANTIC_CHUNKING_NUMBER_OF_CHUNKS",
+        default=None,
+    ),
+}
```
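The keys in `SEMANTIC_CHUNKING_CONFIG` mirror the constructor arguments of langchain's `SemanticChunker` (`buffer_size`, `breakpoint_threshold_type`, `breakpoint_threshold_amount`, `number_of_chunks`), so the dict can be unpacked straight into it. A minimal sketch of that wiring, assuming the real code lives in `vector_search/utils.py`; the helper name `build_semantic_chunker` is hypothetical:

```python
# Hypothetical helper showing how SEMANTIC_CHUNKING_CONFIG can be unpacked
# into langchain's SemanticChunker; the committed wiring may differ.
from django.conf import settings
from langchain_experimental.text_splitter import SemanticChunker


def build_semantic_chunker(encoder):
    # encoder: any BaseEncoder -- it satisfies the langchain Embeddings
    # interface after this commit, so SemanticChunker accepts it directly.
    return SemanticChunker(
        embeddings=encoder,
        **settings.SEMANTIC_CHUNKING_CONFIG,
    )
```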

poetry.lock

Lines changed: 92 additions & 5 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -89,6 +89,8 @@ tiktoken = "^0.8.0"
 llama-index = "^0.12.6"
 llama-index-llms-openai = "^0.3.12"
 llama-index-agent-openai = "^0.4.1"
+langchain-experimental = "^0.3.4"
+langchain-openai = "^0.3.2"
 
 
 [tool.poetry.group.dev.dependencies]
```

vector_search/conftest.py

Lines changed: 28 additions & 4 deletions
```diff
@@ -14,10 +14,10 @@ class DummyEmbedEncoder(BaseEncoder):
     def __init__(self, model_name="dummy-embedding"):
         self.model_name = model_name
 
-    def encode(self, text: str) -> list:  # noqa: ARG002
+    def embed(self, text: str) -> list:  # noqa: ARG002
         return np.random.random((10, 1))
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return np.random.random((10, len(texts)))
 
 
@@ -32,13 +32,37 @@ def _use_test_qdrant_settings(settings, mocker):
     settings.QDRANT_HOST = "https://test"
     settings.QDRANT_BASE_COLLECTION_NAME = "test"
     settings.CONTENT_FILE_EMBEDDING_CHUNK_OVERLAP = 0
+    settings.CONTENT_FILE_EMBEDDING_SEMANTIC_CHUNKING_ENABLED = False
     mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
+    mocker.patch("vector_search.utils.SemanticChunker")
+
     mock_qdrant.scroll.return_value = [
         [],
         None,
     ]
-    get_text_splitter_patch = mocker.patch("vector_search.utils._get_text_splitter")
-    get_text_splitter_patch.return_value = RecursiveCharacterTextSplitter()
+    get_text_splitter_patch = mocker.patch("vector_search.utils._chunk_documents")
+    get_text_splitter_patch.return_value = (
+        RecursiveCharacterTextSplitter().create_documents(
+            texts=["test dociment"],
+            metadatas=[
+                {
+                    "run_title": "",
+                    "platform": "",
+                    "offered_by": "",
+                    "run_readable_id": "",
+                    "resource_readable_id": "",
+                    "content_type": "",
+                    "file_extension": "",
+                    "content_feature_type": "",
+                    "course_number": "",
+                    "file_type": "",
+                    "description": "",
+                    "key": "",
+                    "url": "",
+                }
+            ],
+        )
+    )
     mock_qdrant.count.return_value = CountResult(count=10)
     mocker.patch(
         "vector_search.utils.qdrant_client",
```

vector_search/encoders/base.py

Lines changed: 18 additions & 10 deletions
```diff
@@ -1,9 +1,12 @@
 from abc import ABC, abstractmethod
 
+from langchain_core.embeddings import Embeddings
 
-class BaseEncoder(ABC):
+
+class BaseEncoder(Embeddings, ABC):
     """
-    Base encoder class
+    Base encoder class.
+    Conforms to the langchain Embeddings interface
     """
 
     def model_short_name(self):
@@ -17,21 +20,26 @@ def model_short_name(self):
         model_name = split_model_name[1]
         return model_name
 
-    def encode(self, text):
+    def embed(self, text):
         """
         Embed a single text
         """
-        return next(iter(self.encode_batch([text])))
+        return next(iter(self.embed_documents([text])))
+
+    def dim(self):
+        """
+        Return the dimension of the embeddings
+        """
+        return len(self.embed("test"))
 
     @abstractmethod
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
+    def embed_documents(self, documents):
         """
-        Embed multiple texts
+        Embed a list of documents
         """
-        return [self.encode(text) for text in texts]
 
-    def dim(self):
+    def embed_query(self, query):
         """
-        Return the dimension of the embeddings
+        Embed a query
         """
-        return len(self.encode("test"))
+        return self.embed(query)
```
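After this change, `embed_documents` is the only abstract method; `embed`, `embed_query`, and `dim` all derive from it. A toy subclass, purely illustrative (the fixed 3-dimensional vectors are fabricated for the example), showing what the interface now requires:

```python
# Illustrative subclass of the new BaseEncoder interface.
from vector_search.encoders.base import BaseEncoder


class ToyEncoder(BaseEncoder):
    def __init__(self, model_name="toy/toy-embedding"):
        self.model_name = model_name

    def embed_documents(self, documents):
        # One deterministic 3-dim vector per document.
        return [[float(len(doc)), 1.0, 0.0] for doc in documents]


encoder = ToyEncoder()
assert encoder.embed_query("hello") == encoder.embed("hello")  # embed_query delegates to embed
assert encoder.dim() == 3  # dim() embeds "test" and measures the vector length
```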

vector_search/encoders/fastembed.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -12,8 +12,8 @@ def __init__(self, model_name="BAAI/bge-small-en-v1.5"):
         self.model_name = model_name
         self.model = TextEmbedding(model_name=model_name, lazy_load=True)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return self.model.embed(texts)
+    def embed_documents(self, documents):
+        return list(self.model.embed(documents))
 
     def dim(self):
         """
```

vector_search/encoders/litellm.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,8 +24,8 @@ def __init__(self, model_name="text-embedding-3-small"):
             msg = f"Model {model_name} not found in tiktoken. defaulting to None"
             log.warning(msg)
 
-    def encode_batch(self, texts: list[str]) -> list[list[float]]:
-        return [result["embedding"] for result in self.get_embedding(texts)["data"]]
+    def embed_documents(self, documents):
+        return [result["embedding"] for result in self.get_embedding(documents)["data"]]
 
     def get_embedding(self, texts):
         if settings.LITELLM_CUSTOM_PROVIDER and settings.LITELLM_API_BASE:
```
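With the rename, the LiteLLM-backed encoder can be consumed anywhere langchain expects an `Embeddings` object. A usage sketch, assuming an OpenAI key (or equivalent litellm provider config) is set and that `get_embedding`, only partially shown in this hunk, proxies to the provider:

```python
# Usage sketch; assumes OPENAI_API_KEY or equivalent litellm config is set.
from vector_search.encoders.litellm import LiteLLMEncoder

encoder = LiteLLMEncoder(model_name="text-embedding-3-small")
doc_vectors = encoder.embed_documents(["first chunk", "second chunk"])
query_vector = encoder.embed_query("what is semantic chunking?")
print(len(doc_vectors), encoder.dim())  # 2 1536 -- text-embedding-3-small default
```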
