Skip to content

Commit 286361a

Browse files
yishan-puYishan Pu
and
Yishan Pu
authored
feat: add code samples for tuning with intermediate checkpoints (#13366)
* feat: add code samples for tuning with intermediate checkpoints * feat: add code samples for tuning with intermediate checkpoints --------- Co-authored-by: Yishan Pu <[email protected]>
1 parent 56e38df commit 286361a

8 files changed

+445
-2
lines changed

genai/tuning/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
google-genai==1.7.0
1+
google-genai==1.15.0

genai/tuning/test_tuning_examples.py

+174-1
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from unittest.mock import MagicMock, patch
15+
from unittest.mock import call, MagicMock, patch
1616

1717
from google.genai import types
1818

1919
import tuning_job_create
2020
import tuning_job_get
2121
import tuning_job_list
2222
import tuning_textgen_with_txt
23+
import tuning_with_checkpoints_create
24+
import tuning_with_checkpoints_get_model
25+
import tuning_with_checkpoints_list_checkpoints
26+
import tuning_with_checkpoints_set_default_checkpoint
27+
import tuning_with_checkpoints_textgen_with_txt
2328

2429

2530
@patch("google.genai.Client")
@@ -113,3 +118,171 @@ def test_tuning_textgen_with_txt(mock_genai_client: MagicMock) -> None:
113118
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
114119
mock_genai_client.return_value.tunings.get.assert_called_once()
115120
mock_genai_client.return_value.models.generate_content.assert_called_once()
121+
122+
123+
@patch("google.genai.Client")
124+
def test_tuning_job_create_with_checkpoints(mock_genai_client: MagicMock) -> None:
125+
# Mock the API response
126+
mock_tuning_job = types.TuningJob(
127+
name="test-tuning-job",
128+
experiment="test-experiment",
129+
tuned_model=types.TunedModel(
130+
model="test-model",
131+
endpoint="test-endpoint-2",
132+
checkpoints=[
133+
types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"),
134+
types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"),
135+
]
136+
)
137+
)
138+
mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job
139+
140+
response = tuning_with_checkpoints_create.create_with_checkpoints()
141+
142+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
143+
mock_genai_client.return_value.tunings.tune.assert_called_once()
144+
assert response == "test-tuning-job"
145+
146+
147+
@patch("google.genai.Client")
148+
def test_tuning_with_checkpoints_get_model(mock_genai_client: MagicMock) -> None:
149+
# Mock the API response
150+
mock_tuning_job = types.TuningJob(
151+
name="test-tuning-job",
152+
experiment="test-experiment",
153+
tuned_model=types.TunedModel(
154+
model="test-model",
155+
endpoint="test-endpoint-2",
156+
checkpoints=[
157+
types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"),
158+
types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"),
159+
]
160+
)
161+
)
162+
mock_model = types.Model(
163+
name="test-model",
164+
default_checkpoint_id="2",
165+
checkpoints=[
166+
types.Checkpoint(checkpoint_id="1", epoch=1, step=10),
167+
types.Checkpoint(checkpoint_id="2", epoch=2, step=20),
168+
]
169+
)
170+
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
171+
mock_genai_client.return_value.models.get.return_value = mock_model
172+
173+
response = tuning_with_checkpoints_get_model.get_tuned_model_with_checkpoints("test-tuning-job")
174+
175+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
176+
mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job")
177+
mock_genai_client.return_value.models.get.assert_called_once_with(model="test-model")
178+
assert response == "test-model"
179+
180+
181+
@patch("google.genai.Client")
182+
def test_tuning_with_checkpoints_list_checkpoints(mock_genai_client: MagicMock) -> None:
183+
# Mock the API response
184+
mock_tuning_job = types.TuningJob(
185+
name="test-tuning-job",
186+
experiment="test-experiment",
187+
tuned_model=types.TunedModel(
188+
model="test-model",
189+
endpoint="test-endpoint-2",
190+
checkpoints=[
191+
types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"),
192+
types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"),
193+
]
194+
)
195+
)
196+
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
197+
198+
response = tuning_with_checkpoints_list_checkpoints.list_checkpoints("test-tuning-job")
199+
200+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
201+
mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job")
202+
assert response == "test-tuning-job"
203+
204+
205+
@patch("google.genai.Client")
206+
def test_tuning_with_checkpoints_set_default_checkpoint(mock_genai_client: MagicMock) -> None:
207+
# Mock the API response
208+
mock_tuning_job = types.TuningJob(
209+
name="test-tuning-job",
210+
experiment="test-experiment",
211+
tuned_model=types.TunedModel(
212+
model="test-model",
213+
endpoint="test-endpoint-2",
214+
checkpoints=[
215+
types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"),
216+
types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"),
217+
]
218+
)
219+
)
220+
mock_model = types.Model(
221+
name="test-model",
222+
default_checkpoint_id="2",
223+
checkpoints=[
224+
types.Checkpoint(checkpoint_id="1", epoch=1, step=10),
225+
types.Checkpoint(checkpoint_id="2", epoch=2, step=20),
226+
]
227+
)
228+
mock_updated_model = types.Model(
229+
name="test-model",
230+
default_checkpoint_id="1",
231+
checkpoints=[
232+
types.Checkpoint(checkpoint_id="1", epoch=1, step=10),
233+
types.Checkpoint(checkpoint_id="2", epoch=2, step=20),
234+
]
235+
)
236+
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
237+
mock_genai_client.return_value.models.get.return_value = mock_model
238+
mock_genai_client.return_value.models.update.return_value = mock_updated_model
239+
240+
response = tuning_with_checkpoints_set_default_checkpoint.set_default_checkpoint("test-tuning-job", "1")
241+
242+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
243+
mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job")
244+
mock_genai_client.return_value.models.get.assert_called_once_with(model="test-model")
245+
mock_genai_client.return_value.models.update.assert_called_once()
246+
assert response == "1"
247+
248+
249+
@patch("google.genai.Client")
250+
def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock) -> None:
251+
# Mock the API response
252+
mock_tuning_job = types.TuningJob(
253+
name="test-tuning-job",
254+
experiment="test-experiment",
255+
tuned_model=types.TunedModel(
256+
model="test-model",
257+
endpoint="test-endpoint-2",
258+
checkpoints=[
259+
types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"),
260+
types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"),
261+
]
262+
)
263+
)
264+
mock_response = types.GenerateContentResponse._from_response( # pylint: disable=protected-access
265+
response={
266+
"candidates": [
267+
{
268+
"content": {
269+
"parts": [{"text": "This is a mocked answer."}]
270+
}
271+
}
272+
]
273+
},
274+
kwargs={},
275+
)
276+
277+
mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job
278+
mock_genai_client.return_value.models.generate_content.return_value = mock_response
279+
280+
tuning_with_checkpoints_textgen_with_txt.test_checkpoint("test-tuning-job")
281+
282+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
283+
mock_genai_client.return_value.tunings.get.assert_called_once()
284+
assert mock_genai_client.return_value.models.generate_content.call_args_list == [
285+
call(model="test-endpoint-2", contents="Why is the sky blue?"),
286+
call(model="test-endpoint-1", contents="Why is the sky blue?"),
287+
call(model="test-endpoint-2", contents="Why is the sky blue?"),
288+
]

genai/tuning/tuning_job_create.py

+7
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def create_tuning_job() -> str:
4848
# projects/123456789012/locations/us-central1/endpoints/123456789012345
4949
# projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678
5050

51+
if tuning_job.tuned_model.checkpoints:
52+
for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
53+
print(f"Checkpoint {i + 1}: ", checkpoint)
54+
# Example response:
55+
# Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
56+
# Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'
57+
5158
# [END googlegenaisdk_tuning_job_create]
5259
return tuning_job.name
5360

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_with_checkpoints() -> str:
17+
# [START googlegenaisdk_tuning_with_checkpoints_create]
18+
import time
19+
20+
from google import genai
21+
from google.genai.types import HttpOptions, CreateTuningJobConfig
22+
23+
client = genai.Client(http_options=HttpOptions(api_version="v1"))
24+
25+
tuning_job = client.tunings.tune(
26+
base_model="gemini-2.0-flash-lite-001",
27+
training_dataset="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
28+
config=CreateTuningJobConfig(
29+
tuned_model_display_name="Example tuning job",
30+
# Set to True to disable tuning intermediate checkpoints. Default is False.
31+
export_last_checkpoint_only=False,
32+
),
33+
)
34+
35+
running_states = set([
36+
"JOB_STATE_PENDING",
37+
"JOB_STATE_RUNNING",
38+
])
39+
40+
while tuning_job.state in running_states:
41+
print(tuning_job.state)
42+
tuning_job = client.tunings.get(name=tuning_job.name)
43+
time.sleep(60)
44+
45+
print(tuning_job.tuned_model.model)
46+
print(tuning_job.tuned_model.endpoint)
47+
print(tuning_job.experiment)
48+
# Example response:
49+
# projects/123456789012/locations/us-central1/models/1234567890@1
50+
# projects/123456789012/locations/us-central1/endpoints/123456789012345
51+
# projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678
52+
53+
if tuning_job.tuned_model.checkpoints:
54+
for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
55+
print(f"Checkpoint {i + 1}: ", checkpoint)
56+
# Example response:
57+
# Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
58+
# Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'
59+
60+
# [END googlegenaisdk_tuning_with_checkpoints_create]
61+
return tuning_job.name
62+
63+
64+
if __name__ == "__main__":
65+
create_with_checkpoints()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def get_tuned_model_with_checkpoints(name: str) -> str:
17+
# [START googlegenaisdk_tuning_with_checkpoints_get_model]
18+
from google import genai
19+
from google.genai.types import HttpOptions
20+
21+
client = genai.Client(http_options=HttpOptions(api_version="v1"))
22+
23+
# Get the tuning job and the tuned model.
24+
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=name)
26+
tuned_model = client.models.get(model=tuning_job.tuned_model.model)
27+
print(tuned_model)
28+
# Example response:
29+
# Model(name='projects/123456789012/locations/us-central1/models/1234567890@1', ...)
30+
31+
print(f"Default checkpoint: {tuned_model.default_checkpoint_id}")
32+
# Example response:
33+
# Default checkpoint: 2
34+
35+
if tuned_model.checkpoints:
36+
for _, checkpoint in enumerate(tuned_model.checkpoints):
37+
print(f"Checkpoint {checkpoint.checkpoint_id}: ", checkpoint)
38+
# Example response:
39+
# Checkpoint 1: checkpoint_id='1' epoch=1 step=10
40+
# Checkpoint 2: checkpoint_id='2' epoch=2 step=20
41+
42+
# [END googlegenaisdk_tuning_with_checkpoints_get_model]
43+
return tuned_model.name
44+
45+
46+
if __name__ == "__main__":
47+
tuning_job_name = input("Tuning job name: ")
48+
get_tuned_model_with_checkpoints(tuning_job_name)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def list_checkpoints(name: str) -> str:
17+
# [START googlegenaisdk_tuning_with_checkpoints_list_checkpoints]
18+
from google import genai
19+
from google.genai.types import HttpOptions
20+
21+
client = genai.Client(http_options=HttpOptions(api_version="v1"))
22+
23+
# Get the tuning job and the tuned model.
24+
# Eg. name = "projects/123456789012/locations/us-central1/tuningJobs/123456789012345"
25+
tuning_job = client.tunings.get(name=name)
26+
27+
if tuning_job.tuned_model.checkpoints:
28+
for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
29+
print(f"Checkpoint {i + 1}: ", checkpoint)
30+
# Example response:
31+
# Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
32+
# Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'
33+
34+
# [END googlegenaisdk_tuning_with_checkpoints_list_checkpoints]
35+
return tuning_job.name
36+
37+
38+
if __name__ == "__main__":
39+
tuning_job_name = input("Tuning job name: ")
40+
list_checkpoints(tuning_job_name)

0 commit comments

Comments
 (0)