Skip to content

Commit 64044e6

Browse files
authored
test: add claude models tests in us-east5 (#908)
* test: add claude models tests in us-east5 * docs fix
1 parent 3031903 commit 64044e6

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

docs/templates/toc.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@
157157
uid: bigframes.ml.llm.PaLM2TextGenerator
158158
- name: PaLM2TextEmbeddingGenerator
159159
uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator
160+
- name: TextEmbeddingGenerator
161+
uid: bigframes.ml.llm.TextEmbeddingGenerator
160162
- name: Claude3TextGenerator
161163
uid: bigframes.ml.llm.Claude3TextGenerator
162164
name: llm

tests/system/small/ml/test_llm.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,14 +405,15 @@ def test_gemini_text_generator_predict_with_params_success(
405405
assert all(series.str.len() > 20)
406406

407407

408-
# TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions.
409408
@pytest.mark.parametrize(
410409
"model_name",
411-
("claude-3-sonnet", "claude-3-haiku"),
410+
("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"),
412411
)
413412
def test_claude3_text_generator_create_load(
414-
dataset_id, model_name, session, bq_connection
413+
dataset_id, model_name, session, session_us_east5, bq_connection
415414
):
415+
if model_name in ("claude-3-5-sonnet", "claude-3-opus"):
416+
session = session_us_east5
416417
claude3_text_generator_model = llm.Claude3TextGenerator(
417418
model_name=model_name, connection_name=bq_connection, session=session
418419
)
@@ -430,12 +431,14 @@ def test_claude3_text_generator_create_load(
430431

431432
@pytest.mark.parametrize(
432433
"model_name",
433-
("claude-3-sonnet", "claude-3-haiku"),
434+
("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"),
434435
)
435436
@pytest.mark.flaky(retries=2)
436437
def test_claude3_text_generator_predict_default_params_success(
437-
llm_text_df, model_name, session, bq_connection
438+
llm_text_df, model_name, session, session_us_east5, bq_connection
438439
):
440+
if model_name in ("claude-3-5-sonnet", "claude-3-opus"):
441+
session = session_us_east5
439442
claude3_text_generator_model = llm.Claude3TextGenerator(
440443
model_name=model_name, connection_name=bq_connection, session=session
441444
)
@@ -448,12 +451,14 @@ def test_claude3_text_generator_predict_default_params_success(
448451

449452
@pytest.mark.parametrize(
450453
"model_name",
451-
("claude-3-sonnet", "claude-3-haiku"),
454+
("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"),
452455
)
453456
@pytest.mark.flaky(retries=2)
454457
def test_claude3_text_generator_predict_with_params_success(
455-
llm_text_df, model_name, session, bq_connection
458+
llm_text_df, model_name, session, session_us_east5, bq_connection
456459
):
460+
if model_name in ("claude-3-5-sonnet", "claude-3-opus"):
461+
session = session_us_east5
457462
claude3_text_generator_model = llm.Claude3TextGenerator(
458463
model_name=model_name, connection_name=bq_connection, session=session
459464
)

0 commit comments

Comments
 (0)