diff --git a/docs/templates/toc.yml b/docs/templates/toc.yml index bab4ad9aac..47d9e97d7a 100644 --- a/docs/templates/toc.yml +++ b/docs/templates/toc.yml @@ -157,6 +157,8 @@ uid: bigframes.ml.llm.PaLM2TextGenerator - name: PaLM2TextEmbeddingGenerator uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator + - name: TextEmbeddingGenerator + uid: bigframes.ml.llm.TextEmbeddingGenerator - name: Claude3TextGenerator uid: bigframes.ml.llm.Claude3TextGenerator name: llm diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 1647eb879f..0c8a1956db 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -405,14 +405,15 @@ def test_gemini_text_generator_predict_with_params_success( assert all(series.str.len() > 20) -# TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions. @pytest.mark.parametrize( "model_name", - ("claude-3-sonnet", "claude-3-haiku"), + ("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"), ) def test_claude3_text_generator_create_load( - dataset_id, model_name, session, bq_connection + dataset_id, model_name, session, session_us_east5, bq_connection ): + if model_name in ("claude-3-5-sonnet", "claude-3-opus"): + session = session_us_east5 claude3_text_generator_model = llm.Claude3TextGenerator( model_name=model_name, connection_name=bq_connection, session=session ) @@ -430,12 +431,14 @@ def test_claude3_text_generator_create_load( @pytest.mark.parametrize( "model_name", - ("claude-3-sonnet", "claude-3-haiku"), + ("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"), ) @pytest.mark.flaky(retries=2) def test_claude3_text_generator_predict_default_params_success( - llm_text_df, model_name, session, bq_connection + llm_text_df, model_name, session, session_us_east5, bq_connection ): + if model_name in ("claude-3-5-sonnet", "claude-3-opus"): + session = session_us_east5 claude3_text_generator_model = llm.Claude3TextGenerator( model_name=model_name, connection_name=bq_connection, session=session ) @@ -448,12 +451,14 @@ def test_claude3_text_generator_predict_default_params_success( @pytest.mark.parametrize( "model_name", - ("claude-3-sonnet", "claude-3-haiku"), + ("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"), ) @pytest.mark.flaky(retries=2) def test_claude3_text_generator_predict_with_params_success( - llm_text_df, model_name, session, bq_connection + llm_text_df, model_name, session, session_us_east5, bq_connection ): + if model_name in ("claude-3-5-sonnet", "claude-3-opus"): + session = session_us_east5 claude3_text_generator_model = llm.Claude3TextGenerator( model_name=model_name, connection_name=bq_connection, session=session )