@@ -405,14 +405,15 @@ def test_gemini_text_generator_predict_with_params_success(
405405 assert all (series .str .len () > 20 )
406406
407407
408- # TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions.
409408@pytest .mark .parametrize (
410409 "model_name" ,
411- ("claude-3-sonnet" , "claude-3-haiku" ),
410+ ("claude-3-sonnet" , "claude-3-haiku" , "claude-3-5-sonnet" , "claude-3-opus" ),
412411)
413412def test_claude3_text_generator_create_load (
414- dataset_id , model_name , session , bq_connection
413+ dataset_id , model_name , session , session_us_east5 , bq_connection
415414):
415+ if model_name in ("claude-3-5-sonnet" , "claude-3-opus" ):
416+ session = session_us_east5
416417 claude3_text_generator_model = llm .Claude3TextGenerator (
417418 model_name = model_name , connection_name = bq_connection , session = session
418419 )
@@ -430,12 +431,14 @@ def test_claude3_text_generator_create_load(
430431
431432@pytest .mark .parametrize (
432433 "model_name" ,
433- ("claude-3-sonnet" , "claude-3-haiku" ),
434+ ("claude-3-sonnet" , "claude-3-haiku" , "claude-3-5-sonnet" , "claude-3-opus" ),
434435)
435436@pytest .mark .flaky (retries = 2 )
436437def test_claude3_text_generator_predict_default_params_success (
437- llm_text_df , model_name , session , bq_connection
438+ llm_text_df , model_name , session , session_us_east5 , bq_connection
438439):
440+ if model_name in ("claude-3-5-sonnet" , "claude-3-opus" ):
441+ session = session_us_east5
439442 claude3_text_generator_model = llm .Claude3TextGenerator (
440443 model_name = model_name , connection_name = bq_connection , session = session
441444 )
@@ -448,12 +451,14 @@ def test_claude3_text_generator_predict_default_params_success(
448451
449452@pytest .mark .parametrize (
450453 "model_name" ,
451- ("claude-3-sonnet" , "claude-3-haiku" ),
454+ ("claude-3-sonnet" , "claude-3-haiku" , "claude-3-5-sonnet" , "claude-3-opus" ),
452455)
453456@pytest .mark .flaky (retries = 2 )
454457def test_claude3_text_generator_predict_with_params_success (
455- llm_text_df , model_name , session , bq_connection
458+ llm_text_df , model_name , session , session_us_east5 , bq_connection
456459):
460+ if model_name in ("claude-3-5-sonnet" , "claude-3-opus" ):
461+ session = session_us_east5
457462 claude3_text_generator_model = llm .Claude3TextGenerator (
458463 model_name = model_name , connection_name = bq_connection , session = session
459464 )
0 commit comments