diff --git a/generative_ai/embedding_model_tuning.py b/generative_ai/embedding_model_tuning.py index 55f58d17943..f8e48c1ecd0 100644 --- a/generative_ai/embedding_model_tuning.py +++ b/generative_ai/embedding_model_tuning.py @@ -31,10 +31,10 @@ def tune_embedding_model( corpus_path: str = "gs://embedding-customization-pipeline/dataset/corpus.jsonl", train_label_path: str = "gs://embedding-customization-pipeline/dataset/train.tsv", test_label_path: str = "gs://embedding-customization-pipeline/dataset/test.tsv", - batch_size: int = 50, - iterations: int = 300, + batch_size: int = 128, + iterations: int = 1000, ) -> pipeline_jobs.PipelineJob: - match = re.search(r"(.+)(-autopush|-staging)?-aiplatform.+", api_endpoint) + match = re.search(r"^(\w+-\w+)", api_endpoint) location = match.group(1) if match else "us-central1" job = aiplatform.PipelineJob( display_name=pipeline_job_display_name,