diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 10346b7d..6b6b47d6 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -61,6 +61,7 @@ jobs:
       GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
       GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
       COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
+      MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
       AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}}
       AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}}
       AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}}
@@ -80,6 +81,7 @@ jobs:
       GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
       GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
       COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
+      MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
       AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}}
       AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}}
       AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}}
@@ -92,4 +94,4 @@ jobs:
     with:
       token: ${{ secrets.CODECOV_TOKEN }}
       files: ./coverage.xml
-      fail_ci_if_error: false
\ No newline at end of file
+      fail_ci_if_error: false
diff --git a/README.md b/README.md
index 15c42236..de5324e0 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,7 @@ Commands:
 ### ⚡ Community Integrations
 Integrate with popular embedding models and providers to greatly simplify the process of vectorizing unstructured data for your index and queries:
 - [Cohere](https://www.redisvl.com/api/vectorizer/html#coheretextvectorizer)
+- [Mistral](https://www.redisvl.com/api/vectorizer.html#mistralaitextvectorizer)
 - [OpenAI](https://www.redisvl.com/api/vectorizer.html#openaitextvectorizer)
 - [HuggingFace](https://www.redisvl.com/api/vectorizer.html#hftextvectorizer)
 - [GCP VertexAI](https://www.redisvl.com/api/vectorizer.html#vertexaitextvectorizer)
diff --git a/conftest.py b/conftest.py
index 204ac177..18e63cfa 100644
--- a/conftest.py
+++ b/conftest.py
@@ -59,6 +59,10 @@ def azure_endpoint():
 def cohere_key():
     return os.getenv("COHERE_API_KEY")
 
+@pytest.fixture
+def mistral_key():
+    return os.getenv("MISTRAL_API_KEY")
+
 @pytest.fixture
 def gcp_location():
     return os.getenv("GCP_LOCATION")
@@ -133,4 +137,4 @@ def sample_data():
 def clear_db(redis):
     redis.flushall()
     yield
-    redis.flushall()
\ No newline at end of file
+    redis.flushall()
diff --git a/docs/user_guide/vectorizers_04.ipynb b/docs/user_guide/vectorizers_04.ipynb
index 2dde2bad..c2e70565 100644
--- a/docs/user_guide/vectorizers_04.ipynb
+++ b/docs/user_guide/vectorizers_04.ipynb
@@ -12,6 +12,7 @@
     "2. HuggingFace\n",
     "3. Vertex AI\n",
     "4. Cohere\n",
+    "5. Mistral AI\n",
     "\n",
     "Before running this notebook, be sure to\n",
     "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -500,6 +501,44 @@
     "Learn more about using RedisVL and Cohere together through [this dedicated user guide](https://docs.cohere.com/docs/redis-and-cohere)."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Mistral AI\n",
+    "\n",
+    "[Mistral](https://console.mistral.ai/) offers LLM and embedding APIs that you can integrate into your product. The `MistralAITextVectorizer` makes it simple to use RedisVL with their embedding model. You will need to install `mistralai`.\n",
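+    "You will also need a Mistral API key, which the vectorizer reads from the `MISTRAL_API_KEY` environment variable (or you can pass it directly via `api_config`).\n",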
+    "\n",
+    "```bash\n",
+    "pip install mistralai\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "vector dimensions: 1024\n",
+      "[-0.02801513671875, 0.02532958984375, 0.04278564453125, 0.0185699462890625, 0.041015625, 0.006053924560546875, 0.03607177734375, -0.0030155181884765625, 0.0033893585205078125, -0.01390838623046875]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from redisvl.utils.vectorize import MistralAITextVectorizer\n",
+    "\n",
+    "mistral = MistralAITextVectorizer()\n",
+    "\n",
+    "# embed a sentence using their asynchronous method\n",
+    "test = await mistral.aembed(\"This is a test sentence.\")\n",
+    "print(\"vector dimensions:\", len(test))\n",
+    "print(test[:10])"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
@@ -658,7 +697,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.12.2"
  },
  "orig_nbformat": 4,
  "vscode": {
diff --git a/poetry.lock b/poetry.lock
index afee9c3d..24c3048f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2306,6 +2306,22 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]
 
+[[package]]
+name = "mistralai"
+version = "0.4.1"
+description = ""
+optional = true
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "mistralai-0.4.1-py3-none-any.whl", hash = "sha256:c11d636093c9eec923f00ac9dff13e4619eb751d44d7a3fea5b665a0e8f99f93"},
+    {file = "mistralai-0.4.1.tar.gz", hash = "sha256:22a88c24b9e3176021b466c1d78e6582eef700688803460fd449254fb7647979"},
+]
+
+[package.dependencies]
+httpx = ">=0.25,<1"
+orjson = ">=3.9.10,<3.11"
+pydantic = ">=2.5.2,<3"
+
 [[package]]
 name = "mistune"
 version = "3.0.2"
@@ -2805,6 +2821,7 @@ description = "Nvidia JIT LTO Library"
 optional = true
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
     {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
 ]
@@ -2843,6 +2860,61 @@ typing-extensions = ">=4.7,<5"
 
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
 
+[[package]]
+name = "orjson"
+version = "3.10.5"
+description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "orjson-3.10.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:545d493c1f560d5ccfc134803ceb8955a14c3fcb47bbb4b2fee0232646d0b932"},
+    {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4324929c2dd917598212bfd554757feca3e5e0fa60da08be11b4aa8b90013c1"},
+    {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c13ca5e2ddded0ce6a927ea5a9f27cae77eee4c75547b4297252cb20c4d30e6"},
+    {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b6c8e30adfa52c025f042a87f450a6b9ea29649d828e0fec4858ed5e6caecf63"},
+    {file = 
"orjson-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:338fd4f071b242f26e9ca802f443edc588fa4ab60bfa81f38beaedf42eda226c"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6970ed7a3126cfed873c5d21ece1cd5d6f83ca6c9afb71bbae21a0b034588d96"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:235dadefb793ad12f7fa11e98a480db1f7c6469ff9e3da5e73c7809c700d746b"}, + {file = "orjson-3.10.5-cp310-none-win32.whl", hash = "sha256:be79e2393679eda6a590638abda16d167754393f5d0850dcbca2d0c3735cebe2"}, + {file = "orjson-3.10.5-cp310-none-win_amd64.whl", hash = "sha256:c4a65310ccb5c9910c47b078ba78e2787cb3878cdded1702ac3d0da71ddc5228"}, + {file = "orjson-3.10.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cdf7365063e80899ae3a697def1277c17a7df7ccfc979990a403dfe77bb54d40"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b68742c469745d0e6ca5724506858f75e2f1e5b59a4315861f9e2b1df77775a"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d10cc1b594951522e35a3463da19e899abe6ca95f3c84c69e9e901e0bd93d38"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcbe82b35d1ac43b0d84072408330fd3295c2896973112d495e7234f7e3da2e1"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c0eb7e0c75e1e486c7563fe231b40fdd658a035ae125c6ba651ca3b07936f5"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:53ed1c879b10de56f35daf06dbc4a0d9a5db98f6ee853c2dbd3ee9d13e6f302f"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:099e81a5975237fda3100f918839af95f42f981447ba8f47adb7b6a3cdb078fa"}, + {file = "orjson-3.10.5-cp311-none-win32.whl", hash = "sha256:1146bf85ea37ac421594107195db8bc77104f74bc83e8ee21a2e58596bfb2f04"}, + {file = "orjson-3.10.5-cp311-none-win_amd64.whl", hash = "sha256:36a10f43c5f3a55c2f680efe07aa93ef4a342d2960dd2b1b7ea2dd764fe4a37c"}, + {file = "orjson-3.10.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:68f85ecae7af14a585a563ac741b0547a3f291de81cd1e20903e79f25170458f"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28afa96f496474ce60d3340fe8d9a263aa93ea01201cd2bad844c45cd21f5268"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cd684927af3e11b6e754df80b9ffafd9fb6adcaa9d3e8fdd5891be5a5cad51e"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d21b9983da032505f7050795e98b5d9eee0df903258951566ecc358f6696969"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ad1de7fef79736dde8c3554e75361ec351158a906d747bd901a52a5c9c8d24b"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d97531cdfe9bdd76d492e69800afd97e5930cb0da6a825646667b2c6c6c0211"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d69858c32f09c3e1ce44b617b3ebba1aba030e777000ebdf72b0d8e365d0b2b3"}, + {file = "orjson-3.10.5-cp312-none-win32.whl", hash = "sha256:64c9cc089f127e5875901ac05e5c25aa13cfa5dbbbd9602bda51e5c611d6e3e2"}, + {file = "orjson-3.10.5-cp312-none-win_amd64.whl", hash = 
"sha256:b2efbd67feff8c1f7728937c0d7f6ca8c25ec81373dc8db4ef394c1d93d13dc5"}, + {file = "orjson-3.10.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:03b565c3b93f5d6e001db48b747d31ea3819b89abf041ee10ac6988886d18e01"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:584c902ec19ab7928fd5add1783c909094cc53f31ac7acfada817b0847975f26"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a35455cc0b0b3a1eaf67224035f5388591ec72b9b6136d66b49a553ce9eb1e6"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1670fe88b116c2745a3a30b0f099b699a02bb3482c2591514baf5433819e4f4d"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:185c394ef45b18b9a7d8e8f333606e2e8194a50c6e3c664215aae8cf42c5385e"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ca0b3a94ac8d3886c9581b9f9de3ce858263865fdaa383fbc31c310b9eac07c9"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:dfc91d4720d48e2a709e9c368d5125b4b5899dced34b5400c3837dadc7d6271b"}, + {file = "orjson-3.10.5-cp38-none-win32.whl", hash = "sha256:c05f16701ab2a4ca146d0bca950af254cb7c02f3c01fca8efbbad82d23b3d9d4"}, + {file = "orjson-3.10.5-cp38-none-win_amd64.whl", hash = "sha256:8a11d459338f96a9aa7f232ba95679fc0c7cedbd1b990d736467894210205c09"}, + {file = "orjson-3.10.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:85c89131d7b3218db1b24c4abecea92fd6c7f9fab87441cfc342d3acc725d807"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66215277a230c456f9038d5e2d84778141643207f85336ef8d2a9da26bd7ca"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51bbcdea96cdefa4a9b4461e690c75ad4e33796530d182bdd5c38980202c134a"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbead71dbe65f959b7bd8cf91e0e11d5338033eba34c114f69078d59827ee139"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df58d206e78c40da118a8c14fc189207fffdcb1f21b3b4c9c0c18e839b5a214"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4057c3b511bb8aef605616bd3f1f002a697c7e4da6adf095ca5b84c0fd43595"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b39e006b00c57125ab974362e740c14a0c6a66ff695bff44615dcf4a70ce2b86"}, + {file = "orjson-3.10.5-cp39-none-win32.whl", hash = "sha256:eded5138cc565a9d618e111c6d5c2547bbdd951114eb822f7f6309e04db0fb47"}, + {file = "orjson-3.10.5-cp39-none-win_amd64.whl", hash = "sha256:cc28e90a7cae7fcba2493953cff61da5a52950e78dc2dacfe931a317ee3d8de7"}, + {file = "orjson-3.10.5.tar.gz", hash = "sha256:7a5baef8a4284405d96c90c7c62b755e9ef1ada84c2406c24a9ebec86b89f46d"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -5444,10 +5516,11 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cohere = ["cohere"] google-cloud-aiplatform = ["google-cloud-aiplatform"] +mistralai = ["mistralai"] openai = ["openai"] sentence-transformers = ["sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "dc48145088ceb6ad88c87000d25f85e0ed07fd2778c33addbdb4f18da4505555" +content-hash = 
"be9b5df2ff3600823749e4d0bfffe148c6bb04f88fa287a3dfae712ade9fd06e" diff --git a/pyproject.toml b/pyproject.toml index a0bd132e..91f41677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,14 @@ openai = { version = ">=1.13.0", optional = true } sentence-transformers = { version = ">=2.2.2", optional = true } google-cloud-aiplatform = { version = ">=1.26", optional = true } cohere = { version = ">=4.44", optional = true } +mistralai = { version = ">=0.2.0", optional = true } [tool.poetry.extras] openai = ["openai"] sentence-transformers = ["sentence-transformers"] google_cloud_aiplatform = ["google_cloud_aiplatform"] cohere = ["cohere"] +mistralai = ["mistralai"] [tool.poetry.group.dev.dependencies] black = ">=20.8b1" @@ -116,4 +118,4 @@ directory = "htmlcov" [tool.mypy] warn_unused_configs = true -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true diff --git a/redisvl/utils/vectorize/__init__.py b/redisvl/utils/vectorize/__init__.py index ea9d7bee..6b9e3af9 100644 --- a/redisvl/utils/vectorize/__init__.py +++ b/redisvl/utils/vectorize/__init__.py @@ -2,6 +2,7 @@ from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer from redisvl.utils.vectorize.text.cohere import CohereTextVectorizer from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer +from redisvl.utils.vectorize.text.mistral import MistralAITextVectorizer from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer from redisvl.utils.vectorize.text.vertexai import VertexAITextVectorizer @@ -12,4 +13,5 @@ "OpenAITextVectorizer", "VertexAITextVectorizer", "AzureOpenAITextVectorizer", + "MistralAITextVectorizer", ] diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py new file mode 100644 index 00000000..4bb9c4fd --- /dev/null +++ b/redisvl/utils/vectorize/text/mistral.py @@ -0,0 +1,262 @@ +import os +from typing import Any, Callable, Dict, List, Optional + +from pydantic.v1 import PrivateAttr +from tenacity import retry, stop_after_attempt, wait_random_exponential +from tenacity.retry import retry_if_not_exception_type + +from redisvl.utils.vectorize.base import BaseVectorizer + +# ignore that mistralai isn't imported +# mypy: disable-error-code="name-defined" + + +class MistralAITextVectorizer(BaseVectorizer): + """The MistralAITextVectorizer class utilizes MistralAI's API to generate + embeddings for text data. + + This vectorizer is designed to interact with Mistral's embeddings API, + requiring an API key for authentication. The key can be provided directly + in the `api_config` dictionary or through the `MISTRAL_API_KEY` environment + variable. Users must obtain an API key from Mistral's website + (https://console.mistral.ai/). Additionally, the `mistralai` python client + must be installed with `pip install mistralai`. + + The vectorizer supports both synchronous and asynchronous operations, + allowing for batch processing of texts and flexibility in handling + preprocessing tasks. + + .. 
code-block:: python

+
+        # Synchronous embedding of a single text
+        vectorizer = MistralAITextVectorizer(
+            model="mistral-embed",
+            api_config={"api_key": "your_api_key"}  # OR set MISTRAL_API_KEY in your env
+        )
+        embedding = vectorizer.embed("Hello, world!")
+
+        # Asynchronous batch embedding of multiple texts
+        embeddings = await vectorizer.aembed_many(
+            ["Hello, world!", "How are you?"],
+            batch_size=2
+        )
+
+    """
+
+    _client: Any = PrivateAttr()
+    _aclient: Any = PrivateAttr()
+
+    def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = None):
+        """Initialize the MistralAI vectorizer.
+
+        Args:
+            model (str): Model to use for embedding. Defaults to
+                'mistral-embed'.
+            api_config (Optional[Dict], optional): Dictionary containing the
+                API key. Defaults to None.
+
+        Raises:
+            ImportError: If the mistralai library is not installed.
+            ValueError: If the Mistral API key is not provided.
+        """
+        self._initialize_clients(api_config)
+        super().__init__(model=model, dims=self._set_model_dims(model))
+
+    def _initialize_clients(self, api_config: Optional[Dict]):
+        """
+        Setup the Mistral clients using the provided API key or an
+        environment variable.
+        """
+        # Dynamic import of the mistralai module
+        try:
+            from mistralai.async_client import MistralAsyncClient
+            from mistralai.client import MistralClient
+        except ImportError:
+            raise ImportError(
+                "MistralAI vectorizer requires the mistralai library. "
+                "Please install with `pip install mistralai`"
+            )
+
+        # Fetch the API key from api_config or environment variable
+        api_key = (
+            api_config.get("api_key") if api_config else os.getenv("MISTRAL_API_KEY")
+        )
+        if not api_key:
+            raise ValueError(
+                "Mistral API key is required. "
+                "Provide it in api_config or set the MISTRAL_API_KEY "
+                "environment variable."
+            )
+
+        self._client = MistralClient(api_key=api_key)
+        self._aclient = MistralAsyncClient(api_key=api_key)
+
+    def _set_model_dims(self, model: str) -> int:
+        # Probe the API with a single short input to infer the embedding
+        # dimensions of the chosen model.
+        try:
+            embedding = (
+                self._client.embeddings(model=model, input=["dimension test"])
+                .data[0]
+                .embedding
+            )
+        except (KeyError, IndexError) as ke:
+            raise ValueError(f"Unexpected response from the Mistral API: {str(ke)}")
+        except Exception as e:  # pylint: disable=broad-except
+            # fall back (TODO get more specific)
+            raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
+        return len(embedding)
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_not_exception_type(TypeError),
+    )
+    def embed_many(
+        self,
+        texts: List[str],
+        preprocess: Optional[Callable] = None,
+        batch_size: int = 10,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[List[float]]:
+        """Embed many chunks of texts using the Mistral API.
+
+        Args:
+            texts (List[str]): List of text chunks to embed.
+            preprocess (Optional[Callable], optional): Optional preprocessing
+                callable to perform before vectorization. Defaults to None.
+            batch_size (int, optional): Batch size of texts to use when creating
+                embeddings. Defaults to 10.
+            as_buffer (bool, optional): Whether to convert the raw embedding
+                to a byte string. Defaults to False.
+
+        Returns:
+            List[List[float]]: List of embeddings.
+
+        Raises:
+            TypeError: If the wrong input type is passed in for the text.
+        """
+        if not isinstance(texts, list):
+            raise TypeError("Must pass in a list of str values to embed.")
+        if len(texts) > 0 and not isinstance(texts[0], str):
+            raise TypeError("Must pass in a list of str values to embed.")
+
+        embeddings: List = []
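+        # Each batch is sent as a single API request; the response carries one
+        # embedding per input, in the same order as the batch.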
+        for batch in self.batchify(texts, batch_size, preprocess):
+            response = self._client.embeddings(model=self.model, input=batch)
+            embeddings += [
+                self._process_embedding(r.embedding, as_buffer) for r in response.data
+            ]
+        return embeddings
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_not_exception_type(TypeError),
+    )
+    def embed(
+        self,
+        text: str,
+        preprocess: Optional[Callable] = None,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[float]:
+        """Embed a chunk of text using the Mistral API.
+
+        Args:
+            text (str): Chunk of text to embed.
+            preprocess (Optional[Callable], optional): Optional preprocessing
+                callable to perform before vectorization. Defaults to None.
+            as_buffer (bool, optional): Whether to convert the raw embedding
+                to a byte string. Defaults to False.
+
+        Returns:
+            List[float]: Embedding.
+
+        Raises:
+            TypeError: If the wrong input type is passed in for the text.
+        """
+        if not isinstance(text, str):
+            raise TypeError("Must pass in a str value to embed.")
+
+        if preprocess:
+            text = preprocess(text)
+        result = self._client.embeddings(model=self.model, input=[text])
+        return self._process_embedding(result.data[0].embedding, as_buffer)
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_not_exception_type(TypeError),
+    )
+    async def aembed_many(
+        self,
+        texts: List[str],
+        preprocess: Optional[Callable] = None,
+        batch_size: int = 1000,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[List[float]]:
+        """Asynchronously embed many chunks of texts using the Mistral API.
+
+        Args:
+            texts (List[str]): List of text chunks to embed.
+            preprocess (Optional[Callable], optional): Optional preprocessing
+                callable to perform before vectorization. Defaults to None.
+            batch_size (int, optional): Batch size of texts to use when creating
+                embeddings. Defaults to 1000.
+            as_buffer (bool, optional): Whether to convert the raw embedding
+                to a byte string. Defaults to False.
+
+        Returns:
+            List[List[float]]: List of embeddings.
+
+        Raises:
+            TypeError: If the wrong input type is passed in for the text.
+        """
+        if not isinstance(texts, list):
+            raise TypeError("Must pass in a list of str values to embed.")
+        if len(texts) > 0 and not isinstance(texts[0], str):
+            raise TypeError("Must pass in a list of str values to embed.")
+
+        embeddings: List = []
+        for batch in self.batchify(texts, batch_size, preprocess):
+            response = await self._aclient.embeddings(model=self.model, input=batch)
+            embeddings += [
+                self._process_embedding(r.embedding, as_buffer) for r in response.data
+            ]
+        return embeddings
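+
+    # The tenacity policy below mirrors the synchronous methods: transient API
+    # failures are retried with exponential backoff, while a TypeError from bad
+    # input is raised immediately rather than retried.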
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        retry=retry_if_not_exception_type(TypeError),
+    )
+    async def aembed(
+        self,
+        text: str,
+        preprocess: Optional[Callable] = None,
+        as_buffer: bool = False,
+        **kwargs,
+    ) -> List[float]:
+        """Asynchronously embed a chunk of text using the Mistral API.
+
+        Args:
+            text (str): Chunk of text to embed.
+            preprocess (Optional[Callable], optional): Optional preprocessing
+                callable to perform before vectorization. Defaults to None.
+            as_buffer (bool, optional): Whether to convert the raw embedding
+                to a byte string. Defaults to False.
+
+        Returns:
+            List[float]: Embedding.
+
+        Raises:
+            TypeError: If the wrong input type is passed in for the text.
+        """
+        if not isinstance(text, str):
+            raise TypeError("Must pass in a str value to embed.")
+
+        if preprocess:
+            text = preprocess(text)
+        result = await self._aclient.embeddings(model=self.model, input=[text])
+        return self._process_embedding(result.data[0].embedding, as_buffer)
diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py
index 23952c65..ffc694bf 100644
--- a/tests/integration/test_vectorizers.py
+++ b/tests/integration/test_vectorizers.py
@@ -6,6 +6,7 @@
     AzureOpenAITextVectorizer,
     CohereTextVectorizer,
     HFTextVectorizer,
+    MistralAITextVectorizer,
     OpenAITextVectorizer,
     VertexAITextVectorizer,
 )
@@ -25,6 +26,7 @@ def skip_vectorizer() -> bool:
         VertexAITextVectorizer,
         CohereTextVectorizer,
         AzureOpenAITextVectorizer,
+        MistralAITextVectorizer,
     ]
 )
 def vectorizer(request, skip_vectorizer):
@@ -39,6 +41,8 @@ def vectorizer(request, skip_vectorizer):
         return request.param()
     elif request.param == CohereTextVectorizer:
         return request.param()
+    elif request.param == MistralAITextVectorizer:
+        return request.param()
     elif request.param == AzureOpenAITextVectorizer:
         return request.param(
             model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
@@ -81,7 +85,7 @@ def test_vectorizer_bad_input(vectorizer):
         vectorizer.embed_many(42)
 
 
-@pytest.fixture(params=[OpenAITextVectorizer])
+@pytest.fixture(params=[OpenAITextVectorizer, MistralAITextVectorizer])
 def avectorizer(request, skip_vectorizer):
     if skip_vectorizer:
         pytest.skip("Skipping vectorizer instantiation...")
@@ -89,6 +93,8 @@ def avectorizer(request, skip_vectorizer):
     # Here we use actual models for integration test
     if request.param == OpenAITextVectorizer:
         return request.param()
+    elif request.param == MistralAITextVectorizer:
+        return request.param()
 
 
 @pytest.mark.asyncio