From 599cd63cbdd80403ded8073c6c8c0e69eb0843f0 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 8 Oct 2025 15:17:23 -0700
Subject: [PATCH 1/2] Fix for cohere v4 embeddings

---
 libs/aws/langchain_aws/embeddings/bedrock.py  |  49 ++++++--
 .../embeddings/test_bedrock_embeddings.py     |  15 ++-
 .../tests/unit_tests/embeddings/__init__.py   |   1 +
 .../embeddings/test_bedrock_cohere.py         | 111 ++++++++++++++++++
 4 files changed, 164 insertions(+), 12 deletions(-)
 create mode 100644 libs/aws/tests/unit_tests/embeddings/__init__.py
 create mode 100644 libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py

diff --git a/libs/aws/langchain_aws/embeddings/bedrock.py b/libs/aws/langchain_aws/embeddings/bedrock.py
index 481a7f20..51b09558 100644
--- a/libs/aws/langchain_aws/embeddings/bedrock.py
+++ b/libs/aws/langchain_aws/embeddings/bedrock.py
@@ -152,6 +152,11 @@ def _inferred_provider(self) -> str:
         parts = self.model_id.split(".")
         return parts[1] if parts[0] in regions else parts[0]
 
+    @property
+    def _is_cohere_v4(self) -> bool:
+        """Check if the model is Cohere Embed v4."""
+        return "cohere.embed-v4" in self.model_id
+
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -189,7 +194,12 @@ def _embedding_func(
             embeddings = response_body.get("embeddings")
             if embeddings is None:
                 raise ValueError("No embeddings returned from model")
-            return embeddings[0]
+            # Embed v3 and v4 schemas
+            if isinstance(embeddings, dict) and "float" in embeddings:
+                processed_embeddings = embeddings["float"]
+            else:
+                processed_embeddings = embeddings
+            return processed_embeddings[0]
         else:
             # includes common provider == "amazon"
             response_body = self._invoke_model(
@@ -207,7 +217,9 @@ def _cohere_multi_embedding(self, texts: List[str]) -> List[List[float]]:
         results: List[List[float]] = []
 
         # Iterate through the list of strings in batches
-        for text_batch in _batch_cohere_embedding_texts(texts):
+        for text_batch in _batch_cohere_embedding_texts(
+            texts, is_v4=self._is_cohere_v4
+        ):
             batch_embeddings = self._invoke_model(
                 input_body={
                     "input_type": "search_document",
@@ -344,17 +356,35 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         return list(result)
 
 
-def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None, None]:
+def _batch_cohere_embedding_texts(
+    texts: List[str], is_v4: bool = False
+) -> Generator[List[str], None, None]:
     """Batches a set of texts into chunks acceptable for the Cohere embedding API.
 
-    Chunks of at most 96 items, or 2048 characters.
+    For Cohere Embed v3: Chunks of at most 96 items, or 2048 characters.
+    For Cohere Embed v4: Chunks of at most 96 items, or ~512,000 characters
+    (approx 128K tokens).
     """
-    # Cohere embeddings want a maximum of 96 items and 2048 characters
-    # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
     max_items = 96
-    max_chars = 2048
+    if is_v4:
+        # Cohere Embed v4 supports up to 128K tokens per input
+        # Using conservative estimate of ~4 chars per token = ~512K chars
+        # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed-v4.html
+        max_chars = 512_000
+        char_limit_msg = (
+            "The Cohere Embed v4 embedding API does not support texts longer than "
+            "approximately 128K tokens (~512,000 characters)."
+        )
+    else:
+        # Cohere Embed v3 limit
+        # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
+        max_chars = 2048
+        char_limit_msg = (
+            "The Cohere embedding API does not support texts longer than "
+            "2048 characters."
+        )
 
     # Initialize batches
     current_batch: List[str] = []
@@ -364,10 +394,7 @@ def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None
         text_len = len(text)
 
         if text_len > max_chars:
-            raise ValueError(
-                "The Cohere embedding API does not support texts longer than "
-                "2048 characters."
-            )
+            raise ValueError(char_limit_msg)
 
         # Check if adding the current string would exceed the limits
         if len(current_batch) >= max_items or current_chars + text_len > max_chars:
diff --git a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
index e2da840b..10aeeb54 100644
--- a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
+++ b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
@@ -36,7 +36,7 @@ def cohere_embeddings_v4() -> BedrockEmbeddings:
 @pytest.fixture(
     params=[
         ("cohere.embed-english-v3", 1024),
-        # ("us.cohere.embed-v4:0", 1536),
+        #("us.cohere.embed-v4:0", 1536),
     ]
 )
 def cohere_embeddings(request) -> tuple[BedrockEmbeddings, int]:
@@ -200,3 +200,16 @@ def test_bedrock_embedding_provider_arg(
     assert cohere_embeddings_v3._inferred_provider == "cohere"
     assert cohere_embeddings_v4._inferred_provider == "cohere"
     assert cohere_embeddings_model_arn._inferred_provider == "cohere"
+
+
+#@pytest.mark.scheduled
+@pytest.mark.skip(reason="CI does not have access to v4 embeddings.")
+def test_bedrock_cohere_v4_large_input(cohere_embeddings_v4) -> None:
+    """Test that Cohere v4 can handle inputs larger than v3's 2048 char limit."""
+    # Create a text slightly larger than v3's 2048 char limit
+    large_text = "x" * 3000  # 3000 characters > 2048 limit of v3
+
+    # This should work with v4 (would fail with v3)
+    output = cohere_embeddings_v4.embed_documents([large_text])
+    assert len(output) == 1
+    assert len(output[0]) == 1536  # v4 embedding dimension
diff --git a/libs/aws/tests/unit_tests/embeddings/__init__.py b/libs/aws/tests/unit_tests/embeddings/__init__.py
new file mode 100644
index 00000000..ea763deb
--- /dev/null
+++ b/libs/aws/tests/unit_tests/embeddings/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for embeddings."""
\ No newline at end of file
diff --git a/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py b/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
new file mode 100644
index 00000000..c757e0e5
--- /dev/null
+++ b/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
@@ -0,0 +1,111 @@
+"""Test Cohere v4 embedding fixes."""
+
+import pytest
+from unittest.mock import Mock, patch
+from langchain_aws.embeddings.bedrock import (
+    BedrockEmbeddings,
+    _batch_cohere_embedding_texts,
+)
+
+
+class TestCohereV4Fixes:
+    """Test fixes for Cohere v4 embedding support."""
+
+    def test_is_cohere_v4_property_v4_model(self) -> None:
+        """Test that _is_cohere_v4 returns True for v4 models."""
+        embeddings = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
+        assert embeddings._is_cohere_v4 is True
+
+    def test_is_cohere_v4_property_v3_model(self) -> None:
+        """Test that _is_cohere_v4 returns False for v3 models."""
+        embeddings = BedrockEmbeddings(model_id="cohere.embed-english-v3")
+        assert embeddings._is_cohere_v4 is False
+
+    def test_is_cohere_v4_property_non_cohere_model(self) -> None:
+        """Test that _is_cohere_v4 returns False for non-Cohere models."""
+        embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")
+        assert embeddings._is_cohere_v4 is False
+
+    def test_batch_cohere_v3_limits(self) -> None:
+        """Test that v3 batching respects 2048 character limit."""
+        # Test text under limit
+        short_texts = ["hello"] * 10
+        batches = list(_batch_cohere_embedding_texts(short_texts, is_v4=False))
+        assert len(batches) == 1
+        assert len(batches[0]) == 10
+
+        # Test text over limit
+        long_text = "x" * 2049
+        with pytest.raises(ValueError) as exc_info:
+            list(_batch_cohere_embedding_texts([long_text], is_v4=False))
+        assert "2048 characters" in str(exc_info.value)
+
+    def test_batch_cohere_v4_limits(self) -> None:
+        """Test that v4 batching respects higher character limit."""
+        # Test text that would fail on v3 but pass on v4
+        medium_text = "x" * 10000  # > 2048 but << 512k
+        batches = list(_batch_cohere_embedding_texts([medium_text], is_v4=True))
+        assert len(batches) == 1
+        assert len(batches[0]) == 1
+
+        # Test text over v4 limit
+        huge_text = "x" * 600000  # > 512k chars
+        with pytest.raises(ValueError) as exc_info:
+            list(_batch_cohere_embedding_texts([huge_text], is_v4=True))
+        assert "128K tokens" in str(exc_info.value)
+
+    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
+    def test_embedding_func_cohere_v3_schema(self, mock_create_client) -> None:
+        """Test that _embedding_func handles v3 schema correctly."""
+        mock_client = Mock()
+        mock_create_client.return_value = mock_client
+
+        # Mock v3 response (direct array)
+        mock_client.invoke_model.return_value = {
+            "body": Mock(read=lambda: '{"embeddings": [[0.1, 0.2, 0.3]]}')
+        }
+
+        embeddings = BedrockEmbeddings(model_id="cohere.embed-english-v3")
+        result = embeddings._embedding_func("test text")
+
+        assert result == [0.1, 0.2, 0.3]
+
+    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
+    def test_embedding_func_cohere_v4_schema(self, mock_create_client) -> None:
+        """Test that _embedding_func handles v4 schema correctly."""
+        mock_client = Mock()
+        mock_create_client.return_value = mock_client
+
+        # Mock v4 response (dict with "float" key)
+        mock_client.invoke_model.return_value = {
+            "body": Mock(read=lambda: '{"embeddings": {"float": [[0.1, 0.2, 0.3]]}}')
+        }
+
+        embeddings = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
+        result = embeddings._embedding_func("test text")
+
+        assert result == [0.1, 0.2, 0.3]
+
+    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
+    def test_cohere_multi_embedding_uses_v4_batching(self, mock_create_client) -> None:
+        """Test that _cohere_multi_embedding passes v4 flag to batching function."""
+        mock_client = Mock()
+        mock_create_client.return_value = mock_client
+
+        mock_client.invoke_model.return_value = {
+            "body": Mock(read=lambda: '{"embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]]}}')
+        }
+
+        embeddings = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
+
+        # Use a text that would fail v3 limits but pass v4 limits
+        medium_texts = ["x" * 3000, "y" * 3000]
+
+        with patch('langchain_aws.embeddings.bedrock._batch_cohere_embedding_texts') as mock_batch:
+            mock_batch.return_value = [medium_texts]  # Single batch
+
+            result = embeddings._cohere_multi_embedding(medium_texts)
+
+            # Verify the batching function was called with is_v4=True
+            mock_batch.assert_called_once_with(medium_texts, is_v4=True)
+            assert len(result) == 2
\ No newline at end of file

From 426acbad04644e49f66622745d76b40159bacea9 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 8 Oct 2025 15:22:48 -0700
Subject: [PATCH 2/2] lint

---
 .../embeddings/test_bedrock_embeddings.py     |  6 +--
 .../tests/unit_tests/embeddings/__init__.py   |  2 +-
 .../embeddings/test_bedrock_cohere.py         | 46 +++++++++++--------
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
index 10aeeb54..4d38adbf 100644
--- a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
+++ b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
@@ -36,7 +36,7 @@ def cohere_embeddings_v4() -> BedrockEmbeddings:
 @pytest.fixture(
     params=[
         ("cohere.embed-english-v3", 1024),
-        #("us.cohere.embed-v4:0", 1536),
+        # ("us.cohere.embed-v4:0", 1536),
     ]
 )
 def cohere_embeddings(request) -> tuple[BedrockEmbeddings, int]:
@@ -202,13 +202,13 @@ def test_bedrock_embedding_provider_arg(
     assert cohere_embeddings_model_arn._inferred_provider == "cohere"
 
 
-#@pytest.mark.scheduled
+# @pytest.mark.scheduled
 @pytest.mark.skip(reason="CI does not have access to v4 embeddings.")
 def test_bedrock_cohere_v4_large_input(cohere_embeddings_v4) -> None:
     """Test that Cohere v4 can handle inputs larger than v3's 2048 char limit."""
     # Create a text slightly larger than v3's 2048 char limit
     large_text = "x" * 3000  # 3000 characters > 2048 limit of v3
-
+
     # This should work with v4 (would fail with v3)
     output = cohere_embeddings_v4.embed_documents([large_text])
     assert len(output) == 1
diff --git a/libs/aws/tests/unit_tests/embeddings/__init__.py b/libs/aws/tests/unit_tests/embeddings/__init__.py
index ea763deb..17cd4978 100644
--- a/libs/aws/tests/unit_tests/embeddings/__init__.py
+++ b/libs/aws/tests/unit_tests/embeddings/__init__.py
@@ -1 +1 @@
-"""Unit tests for embeddings."""
\ No newline at end of file
+"""Unit tests for embeddings."""
diff --git a/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py b/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
index c757e0e5..10a39a29 100644
--- a/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
+++ b/libs/aws/tests/unit_tests/embeddings/test_bedrock_cohere.py
@@ -1,7 +1,9 @@
 """Test Cohere v4 embedding fixes."""
 
-import pytest
 from unittest.mock import Mock, patch
+
+import pytest
+
 from langchain_aws.embeddings.bedrock import (
     BedrockEmbeddings,
     _batch_cohere_embedding_texts,
@@ -54,12 +56,12 @@ def test_batch_cohere_v4_limits(self) -> None:
         list(_batch_cohere_embedding_texts([huge_text], is_v4=True))
         assert "128K tokens" in str(exc_info.value)
 
-    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
-    def test_embedding_func_cohere_v3_schema(self, mock_create_client) -> None:
+    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
+    def test_embedding_func_cohere_v3_schema(self, mock_create_client: Mock) -> None:
         """Test that _embedding_func handles v3 schema correctly."""
         mock_client = Mock()
         mock_create_client.return_value = mock_client
-
+
         # Mock v3 response (direct array)
         mock_client.invoke_model.return_value = {
             "body": Mock(read=lambda: '{"embeddings": [[0.1, 0.2, 0.3]]}')
@@ -67,15 +69,15 @@ def test_embedding_func_cohere_v3_schema(self, mock_create_client) -> None:
 
         embeddings = BedrockEmbeddings(model_id="cohere.embed-english-v3")
         result = embeddings._embedding_func("test text")
-
+
         assert result == [0.1, 0.2, 0.3]
 
-    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
-    def test_embedding_func_cohere_v4_schema(self, mock_create_client) -> None:
+    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
+    def test_embedding_func_cohere_v4_schema(self, mock_create_client: Mock) -> None:
         """Test that _embedding_func handles v4 schema correctly."""
         mock_client = Mock()
         mock_create_client.return_value = mock_client
-
+
         # Mock v4 response (dict with "float" key)
         mock_client.invoke_model.return_value = {
             "body": Mock(read=lambda: '{"embeddings": {"float": [[0.1, 0.2, 0.3]]}}')
@@ -83,29 +85,35 @@ def test_embedding_func_cohere_v4_schema(self, mock_create_client) -> None:
 
         embeddings = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
         result = embeddings._embedding_func("test text")
-
+
         assert result == [0.1, 0.2, 0.3]
 
-    @patch('langchain_aws.embeddings.bedrock.create_aws_client')
-    def test_cohere_multi_embedding_uses_v4_batching(self, mock_create_client) -> None:
+    @patch("langchain_aws.embeddings.bedrock.create_aws_client")
+    def test_cohere_multi_embedding_uses_v4_batching(
+        self, mock_create_client: Mock
+    ) -> None:
         """Test that _cohere_multi_embedding passes v4 flag to batching function."""
         mock_client = Mock()
         mock_create_client.return_value = mock_client
-
+
         mock_client.invoke_model.return_value = {
-            "body": Mock(read=lambda: '{"embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]]}}')
+            "body": Mock(
+                read=lambda: '{"embeddings": {"float": [[0.1, 0.2], [0.3, 0.4]]}}'
+            )
         }
 
         embeddings = BedrockEmbeddings(model_id="us.cohere.embed-v4:0")
-
+
         # Use a text that would fail v3 limits but pass v4 limits
         medium_texts = ["x" * 3000, "y" * 3000]
-
-        with patch('langchain_aws.embeddings.bedrock._batch_cohere_embedding_texts') as mock_batch:
+
+        with patch(
+            "langchain_aws.embeddings.bedrock._batch_cohere_embedding_texts"
+        ) as mock_batch:
             mock_batch.return_value = [medium_texts]  # Single batch
-
+
             result = embeddings._cohere_multi_embedding(medium_texts)
-
+
             # Verify the batching function was called with is_v4=True
             mock_batch.assert_called_once_with(medium_texts, is_v4=True)
-            assert len(result) == 2
\ No newline at end of file
+            assert len(result) == 2
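
For context, a minimal usage sketch of the code path this series fixes, assuming AWS credentials and a region with access to the Cohere Embed v4 model. The model id and the 1536-dimension expectation mirror the integration test above; the region name is illustrative and not part of the patch.

# Usage sketch only; not part of either commit.
from langchain_aws import BedrockEmbeddings

embeddings = BedrockEmbeddings(
    model_id="us.cohere.embed-v4:0",  # matched by the _is_cohere_v4 substring check
    region_name="us-west-2",          # illustrative; any region with v4 access
)

# 3000 characters exceeds the 2048-character cap kept for Embed v3; with this
# patch the batching helper accepts it for v4, and the nested
# {"embeddings": {"float": [...]}} response shape is unwrapped before returning.
long_doc = "x" * 3000
vectors = embeddings.embed_documents([long_doc])
print(len(vectors), len(vectors[0]))  # expect 1 vector of 1536 floats (per the test)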