diff --git a/generative_ai/embeddings/batch_example.py b/generative_ai/embeddings/batch_example.py index 91be92de79b..c18ecf8d523 100644 --- a/generative_ai/embeddings/batch_example.py +++ b/generative_ai/embeddings/batch_example.py @@ -39,7 +39,7 @@ def embed_text_batch() -> BatchPredictionJob: output_uri = OUTPUT_URI textembedding_model = language_models.TextEmbeddingModel.from_pretrained( - "textembedding-gecko@003" + "gemini-embedding-001" ) batch_prediction_job = textembedding_model.batch_predict( diff --git a/generative_ai/embeddings/code_retrieval_example.py b/generative_ai/embeddings/code_retrieval_example.py index a8b7f8d213f..4bd88fa9366 100644 --- a/generative_ai/embeddings/code_retrieval_example.py +++ b/generative_ai/embeddings/code_retrieval_example.py @@ -17,24 +17,31 @@ # [START generativeaionvertexai_embedding_code_retrieval] from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel -MODEL_NAME = "text-embedding-005" -DIMENSIONALITY = 256 +MODEL_NAME = "gemini-embedding-001" +DIMENSIONALITY = 3072 def embed_text( texts: list[str] = ["Retrieve a function that adds two numbers"], task: str = "CODE_RETRIEVAL_QUERY", - model_name: str = "text-embedding-005", - dimensionality: int | None = 256, + model_name: str = "gemini-embedding-001", + dimensionality: int | None = 3072, ) -> list[list[float]]: """Embeds texts with a pre-trained, foundational model.""" model = TextEmbeddingModel.from_pretrained(model_name) - inputs = [TextEmbeddingInput(text, task) for text in texts] kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {} - embeddings = model.get_embeddings(inputs, **kwargs) - # Example response: - # [[0.025890009477734566, -0.05553026497364044, 0.006374752148985863,...], - return [embedding.values for embedding in embeddings] + + embeddings = [] + # gemini-embedding-001 takes one input at a time + for text in texts: + text_input = TextEmbeddingInput(text, task) + embedding = model.get_embeddings([text_input], **kwargs) + print(embedding) + # Example response: + # [[0.006135190837085247, -0.01462465338408947, 0.004978656303137541, ...]] + embeddings.append(embedding[0].values) + + return embeddings if __name__ == "__main__": diff --git a/generative_ai/embeddings/document_retrieval_example.py b/generative_ai/embeddings/document_retrieval_example.py index 9cdeba6220a..71e9d6e0a0c 100644 --- a/generative_ai/embeddings/document_retrieval_example.py +++ b/generative_ai/embeddings/document_retrieval_example.py @@ -28,19 +28,24 @@ def embed_text() -> list[list[float]]: # A list of texts to be embedded. texts = ["banana muffins? ", "banana bread? banana muffins?"] # The dimensionality of the output embeddings. - dimensionality = 256 + dimensionality = 3072 # The task type for embedding. Check the available tasks in the model's documentation. task = "RETRIEVAL_DOCUMENT" - model = TextEmbeddingModel.from_pretrained("text-embedding-005") - inputs = [TextEmbeddingInput(text, task) for text in texts] + model = TextEmbeddingModel.from_pretrained("gemini-embedding-001") kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {} - embeddings = model.get_embeddings(inputs, **kwargs) - print(embeddings) - # Example response: - # [[0.006135190837085247, -0.01462465338408947, 0.004978656303137541, ...], [0.1234434666, ...]], - return [embedding.values for embedding in embeddings] + embeddings = [] + # gemini-embedding-001 takes one input at a time + for text in texts: + text_input = TextEmbeddingInput(text, task) + embedding = model.get_embeddings([text_input], **kwargs) + print(embedding) + # Example response: + # [[0.006135190837085247, -0.01462465338408947, 0.004978656303137541, ...]] + embeddings.append(embedding[0].values) + + return embeddings # [END generativeaionvertexai_embedding] diff --git a/generative_ai/embeddings/test_embeddings_examples.py b/generative_ai/embeddings/test_embeddings_examples.py index afa350e50db..b4472d25a56 100644 --- a/generative_ai/embeddings/test_embeddings_examples.py +++ b/generative_ai/embeddings/test_embeddings_examples.py @@ -81,7 +81,7 @@ def test_generate_embeddings_with_lower_dimension() -> None: @backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) def test_text_embed_text() -> None: embeddings = document_retrieval_example.embed_text() - assert [len(e) for e in embeddings] == [256, 256] + assert [len(e) for e in embeddings] == [3072, 3072] @backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)