Commit b5800d4

Vector search example (#1778)

* Vector search example
* addressed feedback
* vectors example tests
* skip vectors integration test for stacks < 8.11 (see the sketch below)

Co-authored-by: Quentin Pradet <[email protected]>

1 parent f56d3a5 commit b5800d4
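
The last bullet's version gate can be sketched as a pytest skip. A minimal, hypothetical version (the `client` fixture and the test body are assumptions, not the actual test added by this commit):

import pytest


def test_vector_search(client):
    # hypothetical gate: nested kNN with inner_hits needs Elasticsearch >= 8.11
    major, minor = (int(v) for v in client.info()["version"]["number"].split(".")[:2])
    if (major, minor) < (8, 11):
        pytest.skip("requires Elasticsearch 8.11 or higher")
    # ... run the vectors example against the test cluster ...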

File tree

6 files changed: +445 −0 lines

examples/async/vectors.py (new file, +186 lines)

# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example

Requirements:

$ pip install nltk sentence_transformers tqdm elasticsearch-dsl[async]

To run the example:

$ python vectors.py "text to search"

The index will be created automatically if it does not exist. Add
`--recreate-index` to regenerate it.

The example dataset includes a selection of workplace documents. The
following are good example queries to try out with this dataset:

$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "can I bring a bird to work?"

When the index is created, the documents are split into short passages, and
for each passage an embedding is generated using the open source
"all-MiniLM-L6-v2" model. The documents that are returned as search results
are those that have the highest scored passages. Add `--show-inner-hits` to
the command to see individual passage results as well.
"""

import argparse
import asyncio
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    AsyncDocument,
    Date,
    DenseVector,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    async_connections,
)

DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(AsyncDocument):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


async def create():
    # create the index
    await WorkplaceDoc._index.delete(ignore_unavailable=True)
    await WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        await doc.save()


async def search(query):
    model = WorkplaceDoc.get_embedding_model()
    return WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 2},
    )


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--recreate-index", action="store_true", help="Recreate and populate the index"
    )
    parser.add_argument(
        "--show-inner-hits",
        action="store_true",
        help="Show results for individual passages",
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


async def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.recreate_index or not await WorkplaceDoc._index.exists():
        await create()

    results = await search(args.query)

    async for hit in results:
        print(
            f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]"
        )
        print(f"Summary: {hit.summary}")
        if args.show_inner_hits:
            for passage in hit.meta.inner_hits.passages:
                print(f"  - [Score: {passage.meta.score}] {passage.content!r}")
        print("")

    # close the connection
    await async_connections.get_connection().close()


if __name__ == "__main__":
    asyncio.run(main())
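
For reference, the `knn()` call in `search()` should serialize to a search request body roughly like the sketch below (the query vector, 384 floats produced by "all-MiniLM-L6-v2", is elided). `inner_hits` on a nested kNN field is presumably the feature that motivates the "< 8.11" test skip mentioned in the commit message:

# approximate request body built by WorkplaceDoc.search().knn(...) above (sketch)
{
    "knn": {
        "field": "passages.embedding",
        "k": 5,
        "num_candidates": 50,
        "query_vector": [...],  # 384-dimensional embedding of the query text
        "inner_hits": {"size": 2},  # return the 2 best passages per document
    }
}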

examples/vectors.py (new file, +185 lines)

# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example

Requirements:

$ pip install nltk sentence_transformers tqdm elasticsearch-dsl

To run the example:

$ python vectors.py "text to search"

The index will be created automatically if it does not exist. Add
`--recreate-index` to regenerate it.

The example dataset includes a selection of workplace documents. The
following are good example queries to try out with this dataset:

$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "can I bring a bird to work?"

When the index is created, the documents are split into short passages, and
for each passage an embedding is generated using the open source
"all-MiniLM-L6-v2" model. The documents that are returned as search results
are those that have the highest scored passages. Add `--show-inner-hits` to
the command to see individual passage results as well.
"""

import argparse
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
    Date,
    DenseVector,
    Document,
    InnerDoc,
    Keyword,
    Nested,
    Text,
    connections,
)

DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = DenseVector()


class WorkplaceDoc(Document):
    class Index:
        name = "workplace_documents"

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    @classmethod
    def get_embedding_model(cls):
        if cls._model is None:
            cls._model = SentenceTransformer(MODEL_NAME)
        return cls._model

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        model = self.get_embedding_model()
        for passage in passages:
            self.passages.append(
                Passage(content=passage, embedding=list(model.encode(passage)))
            )


def create():
    # create the index
    WorkplaceDoc._index.delete(ignore_unavailable=True)
    WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        doc.save()


def search(query):
    model = WorkplaceDoc.get_embedding_model()
    return WorkplaceDoc.search().knn(
        field="passages.embedding",
        k=5,
        num_candidates=50,
        query_vector=list(model.encode(query)),
        inner_hits={"size": 2},
    )


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--recreate-index", action="store_true", help="Recreate and populate the index"
    )
    parser.add_argument(
        "--show-inner-hits",
        action="store_true",
        help="Show results for individual passages",
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.recreate_index or not WorkplaceDoc._index.exists():
        create()

    results = search(args.query)

    for hit in results:
        print(
            f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]"
        )
        print(f"Summary: {hit.summary}")
        if args.show_inner_hits:
            for passage in hit.meta.inner_hits.passages:
                print(f"  - [Score: {passage.meta.score}] {passage.content!r}")
        print("")

    # close the connection
    connections.get_connection().close()


if __name__ == "__main__":
    main()
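
To try either script end to end, point `ELASTICSEARCH_URL` at a cluster and pass a query (the localhost URL below is an assumption; the first run downloads the dataset and the embedding model and builds the index):

$ export ELASTICSEARCH_URL=http://localhost:9200
$ python examples/vectors.py --show-inner-hits "can I bring a bird to work?"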

setup.py (+4 lines)

@@ -45,6 +45,10 @@
     "pytest-asyncio",
     "pytz",
     "coverage",
+    # the following three are used by the vectors example and its tests
+    "nltk",
+    "sentence_transformers",
+    "tqdm",
     # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5)
     "sphinx>2",
     "sphinx-rtd-theme>0.5",
