Skip to content

Commit 72c6ad2

Browse files
committed
Add Amazon Bedrock Text vectorizer (#143)
1 parent f280c64 commit 72c6ad2

File tree

9 files changed

+621
-59
lines changed

9 files changed

+621
-59
lines changed

.github/workflows/run_tests.yml

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ jobs:
6666
AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}}
6767
AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}}
6868
OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}}
69+
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
70+
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
6971
run: |
7072
poetry run test-cov
7173

conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ def gcp_location():
7171
def gcp_project_id():
7272
return os.getenv("GCP_PROJECT_ID")
7373

74+
@pytest.fixture
def aws_credentials():
    """Return AWS credential settings read from environment variables.

    The access key and secret key are taken verbatim from the environment
    (``None`` when unset); the region falls back to ``us-east-1`` when
    ``AWS_REGION`` is not defined.
    """
    access_key = os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
    region = os.getenv("AWS_REGION", "us-east-1")

    return {
        "aws_access_key_id": access_key,
        "aws_secret_access_key": secret_key,
        "aws_region": region,
    }
7481

7582
@pytest.fixture
7683
def sample_data():

docs/api/vectorizer.rst

+10
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ CohereTextVectorizer
6161
:show-inheritance:
6262
:members:
6363

64+
BedrockTextVectorizer
65+
=====================
66+
67+
.. _bedrocktextvectorizer_api:
68+
69+
.. currentmodule:: redisvl.utils.vectorize.text.bedrock
70+
71+
.. autoclass:: BedrockTextVectorizer
72+
:show-inheritance:
73+
:members:
6474

6575
CustomTextVectorizer
6676
====================

docs/user_guide/vectorizers_04.ipynb

+75-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"3. Vertex AI\n",
1414
"4. Cohere\n",
1515
"5. Mistral AI\n",
16-
"6. Bringing your own vectorizer\n",
16+
"6. Amazon Bedrock\n",
17+
"7. Bringing your own vectorizer\n",
1718
"\n",
1819
"Before running this notebook, be sure to\n",
1920
"1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -541,6 +542,76 @@
541542
"# print(test[:10])"
542543
]
543544
},
545+
{
546+
"cell_type": "markdown",
547+
"metadata": {},
548+
"source": [
549+
"### Amazon Bedrock\n",
550+
"\n",
551+
"Amazon Bedrock provides fully managed foundation models for text embeddings. Install the required dependencies:\n",
552+
"\n",
553+
"```bash\n",
554+
"pip install 'redisvl[bedrock]' # Installs boto3\n",
555+
"```"
556+
]
557+
},
558+
{
559+
"cell_type": "markdown",
560+
"metadata": {},
561+
"source": [
562+
"#### Configure AWS credentials:"
563+
]
564+
},
565+
{
566+
"cell_type": "code",
567+
"execution_count": null,
568+
"metadata": {},
569+
"outputs": [],
570+
"source": [
571+
"import os\n",
572+
"import getpass\n",
573+
"\n",
574+
"if \"AWS_ACCESS_KEY_ID\" not in os.environ:\n",
575+
" os.environ[\"AWS_ACCESS_KEY_ID\"] = getpass.getpass(\"Enter AWS Access Key ID: \")\n",
576+
"if \"AWS_SECRET_ACCESS_KEY\" not in os.environ:\n",
577+
" os.environ[\"AWS_SECRET_ACCESS_KEY\"] = getpass.getpass(\"Enter AWS Secret Key: \")\n",
578+
"\n",
579+
"os.environ[\"AWS_REGION\"] = \"us-east-1\" # Change as needed"
580+
]
581+
},
582+
{
583+
"cell_type": "markdown",
584+
"metadata": {},
585+
"source": [
586+
"#### Create embeddings:"
587+
]
588+
},
589+
{
590+
"cell_type": "code",
591+
"execution_count": null,
592+
"metadata": {},
593+
"outputs": [],
594+
"source": [
595+
"from redisvl.utils.vectorize import BedrockTextVectorizer\n",
596+
"\n",
597+
"bedrock = BedrockTextVectorizer(\n",
598+
" model=\"amazon.titan-embed-text-v2:0\"\n",
599+
")\n",
600+
"\n",
601+
"# Single embedding\n",
602+
"text = \"This is a test sentence.\"\n",
603+
"embedding = bedrock.embed(text)\n",
604+
"print(f\"Vector dimensions: {len(embedding)}\")\n",
605+
"\n",
606+
"# Multiple embeddings\n",
607+
"sentences = [\n",
608+
" \"That is a happy dog\",\n",
609+
" \"That is a happy person\",\n",
610+
" \"Today is a sunny day\"\n",
611+
"]\n",
612+
"embeddings = bedrock.embed_many(sentences)"
613+
]
614+
},
544615
{
545616
"cell_type": "markdown",
546617
"metadata": {},
@@ -691,7 +762,7 @@
691762
},
692763
{
693764
"cell_type": "code",
694-
"execution_count": 17,
765+
"execution_count": null,
695766
"metadata": {},
696767
"outputs": [
697768
{
@@ -710,9 +781,10 @@
710781
"source": [
711782
"# load expects an iterable of dictionaries where\n",
712783
"# the vector is stored as a bytes buffer\n",
784+
"from redisvl.redis.utils import array_to_buffer\n",
713785
"\n",
714786
"data = [{\"text\": t,\n",
715-
" \"embedding\": v}\n",
787+
" \"embedding\": array_to_buffer(v, dtype=\"float32\")}\n",
716788
" for t, v in zip(sentences, embeddings)]\n",
717789
"\n",
718790
"index.load(data)"

poetry.lock

+278-40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
3232
google-cloud-aiplatform = { version = ">=1.26", optional = true }
3333
cohere = { version = ">=4.44", optional = true }
3434
mistralai = { version = ">=0.2.0", optional = true }
35+
boto3 = { version = ">=1.34.0", optional = true }
3536

3637
[tool.poetry.extras]
3738
openai = ["openai"]
3839
sentence-transformers = ["sentence-transformers"]
3940
google_cloud_aiplatform = ["google_cloud_aiplatform"]
4041
cohere = ["cohere"]
4142
mistralai = ["mistralai"]
43+
bedrock = ["boto3"]
4244

4345
[tool.poetry.group.dev.dependencies]
4446
black = ">=20.8b1"

redisvl/utils/vectorize/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers
22
from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer
3+
from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer
34
from redisvl.utils.vectorize.text.cohere import CohereTextVectorizer
45
from redisvl.utils.vectorize.text.custom import CustomTextVectorizer
56
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
@@ -8,14 +9,15 @@
89
from redisvl.utils.vectorize.text.vertexai import VertexAITextVectorizer
910

1011
__all__ = [
11-
"BaseVectrorizer",
12+
"BaseVectorizer",
1213
"CohereTextVectorizer",
1314
"HFTextVectorizer",
1415
"OpenAITextVectorizer",
1516
"VertexAITextVectorizer",
1617
"AzureOpenAITextVectorizer",
1718
"MistralAITextVectorizer",
1819
"CustomTextVectorizer",
20+
"BedrockTextVectorizer",
1921
]
2022

2123

redisvl/utils/vectorize/text/bedrock.py

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import json
2+
import os
3+
from typing import Any, Callable, Dict, List, Optional
4+
5+
from pydantic.v1 import PrivateAttr
6+
from tenacity import retry, stop_after_attempt, wait_random_exponential
7+
from tenacity.retry import retry_if_not_exception_type
8+
9+
from redisvl.utils.vectorize.base import BaseVectorizer
10+
11+
12+
class BedrockTextVectorizer(BaseVectorizer):
    """The BedrockTextVectorizer class utilizes Amazon Bedrock's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with Amazon Bedrock API,
    requiring AWS credentials for authentication. The credentials can be provided
    directly in the `api_config` dictionary or through environment variables:
    - AWS_ACCESS_KEY_ID
    - AWS_SECRET_ACCESS_KEY
    - AWS_REGION (defaults to us-east-1)

    The vectorizer supports synchronous operations with batch processing and
    preprocessing capabilities.

    .. code-block:: python

        # Initialize with explicit credentials
        vectorizer = BedrockTextVectorizer(
            model="amazon.titan-embed-text-v2:0",
            api_config={
                "aws_access_key_id": "your_access_key",
                "aws_secret_access_key": "your_secret_key",
                "aws_region": "us-east-1"
            }
        )

        # Initialize using environment variables
        vectorizer = BedrockTextVectorizer()

        # Generate embeddings
        embedding = vectorizer.embed("Hello, world!")
        embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2)
    """

    # boto3 "bedrock-runtime" client; created in __init__.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "amazon.titan-embed-text-v2:0",
        api_config: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialize the AWS Bedrock Vectorizer.

        Args:
            model (str): The Bedrock model ID to use. Defaults to amazon.titan-embed-text-v2:0
            api_config (Optional[Dict[str, str]]): AWS credentials and config.
                Can include: aws_access_key_id, aws_secret_access_key, aws_region
                Any entry that is missing, None or empty falls back to the
                matching environment variable (and, for the region, to us-east-1).

        Raises:
            ValueError: If credentials are not provided in config or environment,
                or if the probe request used to determine the embedding size fails.
            ImportError: If boto3 is not installed.
        """
        try:
            import boto3  # type: ignore
        except ImportError:
            raise ImportError(
                "Amazon Bedrock vectorizer requires boto3. "
                "Please install with `pip install boto3`"
            )

        if api_config is None:
            api_config = {}

        # Use `or` rather than a .get() default so that keys explicitly set to
        # None / "" still fall back to the environment instead of reaching boto3.
        aws_access_key_id = api_config.get("aws_access_key_id") or os.getenv(
            "AWS_ACCESS_KEY_ID"
        )
        aws_secret_access_key = api_config.get("aws_secret_access_key") or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_region = (
            api_config.get("aws_region") or os.getenv("AWS_REGION") or "us-east-1"
        )

        if not aws_access_key_id or not aws_secret_access_key:
            raise ValueError(
                "AWS credentials required. Provide via api_config or environment variables "
                "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
            )

        self._client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region,
        )

        # The client must exist before this call: the dimension probe below
        # issues a real request through it.
        super().__init__(model=model, dims=self._set_model_dims(model))

    def _invoke_embedding(self, model: str, text: str) -> List[float]:
        """Send one text to the given Bedrock model and return its raw embedding.

        `model` is passed explicitly (instead of reading `self.model`) because
        this helper is also used before the base class has been initialized.
        """
        response = self._client.invoke_model(
            modelId=model, body=json.dumps({"inputText": text})
        )
        response_body = json.loads(response["body"].read())
        return response_body["embedding"]

    def _set_model_dims(self, model: str) -> int:
        """Determine the embedding dimension by embedding a short probe text.

        Raises:
            ValueError: If the probe request or response parsing fails; the
                original exception is attached as the cause.
        """
        try:
            return len(self._invoke_embedding(model, "dimension test"))
        except Exception as e:
            # Chain the cause so the underlying AWS error stays in the traceback.
            raise ValueError(f"Error initializing Bedrock model: {str(e)}") from e

    # Retried with randomized exponential back-off (1-60s) for up to 6 attempts;
    # TypeError (bad caller input) is never retried.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Embed a chunk of text using Amazon Bedrock.

        Args:
            text (str): Text to embed.
            preprocess (Optional[Callable]): Optional preprocessing function,
                applied to the text before it is sent to Bedrock.
            as_buffer (bool): Whether to return as byte buffer.
            **kwargs: An optional ``dtype`` entry is forwarded to the
                embedding post-processing step.

        Returns:
            List[float]: The embedding vector, as returned by the inherited
            ``_process_embedding`` helper.

        Raises:
            TypeError: If text is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Text must be a string")

        if preprocess:
            text = preprocess(text)

        embedding = self._invoke_embedding(self.model, text)

        dtype = kwargs.pop("dtype", None)
        return self._process_embedding(embedding, as_buffer, dtype)

    # Same retry policy as `embed`. NOTE: the decorator wraps the whole call,
    # so a failure on any text restarts the entire list.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Embed multiple texts using Amazon Bedrock.

        Args:
            texts (List[str]): List of texts to embed.
            preprocess (Optional[Callable]): Optional preprocessing function.
            batch_size (int): Size of batches for processing.
            as_buffer (bool): Whether to return as byte buffers.
            **kwargs: An optional ``dtype`` entry is forwarded to the
                embedding post-processing step.

        Returns:
            List[List[float]]: List of embedding vectors, in input order.

        Raises:
            TypeError: If texts is not a list of strings.
        """
        # Validate every element up front: a non-string anywhere in the list
        # would otherwise be sent to Bedrock and retried with back-off.
        if not isinstance(texts, list) or not all(
            isinstance(text, str) for text in texts
        ):
            raise TypeError("Texts must be a list of strings")

        embeddings: List[List[float]] = []
        dtype = kwargs.pop("dtype", None)

        for batch in self.batchify(texts, batch_size, preprocess):
            # Process each text in the batch individually since Bedrock
            # doesn't support batch embedding
            for text in batch:
                raw_embedding = self._invoke_embedding(self.model, text)
                embeddings.append(
                    self._process_embedding(raw_embedding, as_buffer, dtype)
                )

        return embeddings

    @property
    def type(self) -> str:
        """Identifier string for this vectorizer implementation."""
        return "bedrock"

0 commit comments

Comments
 (0)