|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from unittest.mock import MagicMock, patch |
| 15 | +from unittest.mock import call, MagicMock, patch |
16 | 16 |
|
17 | 17 | from google.genai import types
|
18 | 18 |
|
19 | 19 | import tuning_job_create
|
20 | 20 | import tuning_job_get
|
21 | 21 | import tuning_job_list
|
22 | 22 | import tuning_textgen_with_txt
|
| 23 | +import tuning_with_checkpoints_create |
| 24 | +import tuning_with_checkpoints_get_model |
| 25 | +import tuning_with_checkpoints_list_checkpoints |
| 26 | +import tuning_with_checkpoints_set_default_checkpoint |
| 27 | +import tuning_with_checkpoints_textgen_with_txt |
23 | 28 |
|
24 | 29 |
|
25 | 30 | @patch("google.genai.Client")
|
@@ -113,3 +118,171 @@ def test_tuning_textgen_with_txt(mock_genai_client: MagicMock) -> None:
|
113 | 118 | mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
|
114 | 119 | mock_genai_client.return_value.tunings.get.assert_called_once()
|
115 | 120 | mock_genai_client.return_value.models.generate_content.assert_called_once()
|
| 121 | + |
| 122 | + |
| 123 | +@patch("google.genai.Client") |
| 124 | +def test_tuning_job_create_with_checkpoints(mock_genai_client: MagicMock) -> None: |
| 125 | + # Mock the API response |
| 126 | + mock_tuning_job = types.TuningJob( |
| 127 | + name="test-tuning-job", |
| 128 | + experiment="test-experiment", |
| 129 | + tuned_model=types.TunedModel( |
| 130 | + model="test-model", |
| 131 | + endpoint="test-endpoint-2", |
| 132 | + checkpoints=[ |
| 133 | + types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"), |
| 134 | + types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"), |
| 135 | + ] |
| 136 | + ) |
| 137 | + ) |
| 138 | + mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job |
| 139 | + |
| 140 | + response = tuning_with_checkpoints_create.create_with_checkpoints() |
| 141 | + |
| 142 | + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 143 | + mock_genai_client.return_value.tunings.tune.assert_called_once() |
| 144 | + assert response == "test-tuning-job" |
| 145 | + |
| 146 | + |
| 147 | +@patch("google.genai.Client") |
| 148 | +def test_tuning_with_checkpoints_get_model(mock_genai_client: MagicMock) -> None: |
| 149 | + # Mock the API response |
| 150 | + mock_tuning_job = types.TuningJob( |
| 151 | + name="test-tuning-job", |
| 152 | + experiment="test-experiment", |
| 153 | + tuned_model=types.TunedModel( |
| 154 | + model="test-model", |
| 155 | + endpoint="test-endpoint-2", |
| 156 | + checkpoints=[ |
| 157 | + types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"), |
| 158 | + types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"), |
| 159 | + ] |
| 160 | + ) |
| 161 | + ) |
| 162 | + mock_model = types.Model( |
| 163 | + name="test-model", |
| 164 | + default_checkpoint_id="2", |
| 165 | + checkpoints=[ |
| 166 | + types.Checkpoint(checkpoint_id="1", epoch=1, step=10), |
| 167 | + types.Checkpoint(checkpoint_id="2", epoch=2, step=20), |
| 168 | + ] |
| 169 | + ) |
| 170 | + mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job |
| 171 | + mock_genai_client.return_value.models.get.return_value = mock_model |
| 172 | + |
| 173 | + response = tuning_with_checkpoints_get_model.get_tuned_model_with_checkpoints("test-tuning-job") |
| 174 | + |
| 175 | + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 176 | + mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job") |
| 177 | + mock_genai_client.return_value.models.get.assert_called_once_with(model="test-model") |
| 178 | + assert response == "test-model" |
| 179 | + |
| 180 | + |
| 181 | +@patch("google.genai.Client") |
| 182 | +def test_tuning_with_checkpoints_list_checkpoints(mock_genai_client: MagicMock) -> None: |
| 183 | + # Mock the API response |
| 184 | + mock_tuning_job = types.TuningJob( |
| 185 | + name="test-tuning-job", |
| 186 | + experiment="test-experiment", |
| 187 | + tuned_model=types.TunedModel( |
| 188 | + model="test-model", |
| 189 | + endpoint="test-endpoint-2", |
| 190 | + checkpoints=[ |
| 191 | + types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"), |
| 192 | + types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"), |
| 193 | + ] |
| 194 | + ) |
| 195 | + ) |
| 196 | + mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job |
| 197 | + |
| 198 | + response = tuning_with_checkpoints_list_checkpoints.list_checkpoints("test-tuning-job") |
| 199 | + |
| 200 | + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 201 | + mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job") |
| 202 | + assert response == "test-tuning-job" |
| 203 | + |
| 204 | + |
| 205 | +@patch("google.genai.Client") |
| 206 | +def test_tuning_with_checkpoints_set_default_checkpoint(mock_genai_client: MagicMock) -> None: |
| 207 | + # Mock the API response |
| 208 | + mock_tuning_job = types.TuningJob( |
| 209 | + name="test-tuning-job", |
| 210 | + experiment="test-experiment", |
| 211 | + tuned_model=types.TunedModel( |
| 212 | + model="test-model", |
| 213 | + endpoint="test-endpoint-2", |
| 214 | + checkpoints=[ |
| 215 | + types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"), |
| 216 | + types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"), |
| 217 | + ] |
| 218 | + ) |
| 219 | + ) |
| 220 | + mock_model = types.Model( |
| 221 | + name="test-model", |
| 222 | + default_checkpoint_id="2", |
| 223 | + checkpoints=[ |
| 224 | + types.Checkpoint(checkpoint_id="1", epoch=1, step=10), |
| 225 | + types.Checkpoint(checkpoint_id="2", epoch=2, step=20), |
| 226 | + ] |
| 227 | + ) |
| 228 | + mock_updated_model = types.Model( |
| 229 | + name="test-model", |
| 230 | + default_checkpoint_id="1", |
| 231 | + checkpoints=[ |
| 232 | + types.Checkpoint(checkpoint_id="1", epoch=1, step=10), |
| 233 | + types.Checkpoint(checkpoint_id="2", epoch=2, step=20), |
| 234 | + ] |
| 235 | + ) |
| 236 | + mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job |
| 237 | + mock_genai_client.return_value.models.get.return_value = mock_model |
| 238 | + mock_genai_client.return_value.models.update.return_value = mock_updated_model |
| 239 | + |
| 240 | + response = tuning_with_checkpoints_set_default_checkpoint.set_default_checkpoint("test-tuning-job", "1") |
| 241 | + |
| 242 | + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 243 | + mock_genai_client.return_value.tunings.get.assert_called_once_with(name="test-tuning-job") |
| 244 | + mock_genai_client.return_value.models.get.assert_called_once_with(model="test-model") |
| 245 | + mock_genai_client.return_value.models.update.assert_called_once() |
| 246 | + assert response == "1" |
| 247 | + |
| 248 | + |
| 249 | +@patch("google.genai.Client") |
| 250 | +def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock) -> None: |
| 251 | + # Mock the API response |
| 252 | + mock_tuning_job = types.TuningJob( |
| 253 | + name="test-tuning-job", |
| 254 | + experiment="test-experiment", |
| 255 | + tuned_model=types.TunedModel( |
| 256 | + model="test-model", |
| 257 | + endpoint="test-endpoint-2", |
| 258 | + checkpoints=[ |
| 259 | + types.TunedModelCheckpoint(checkpoint_id="1", epoch=1, step=10, endpoint="test-endpoint-1"), |
| 260 | + types.TunedModelCheckpoint(checkpoint_id="2", epoch=2, step=20, endpoint="test-endpoint-2"), |
| 261 | + ] |
| 262 | + ) |
| 263 | + ) |
| 264 | + mock_response = types.GenerateContentResponse._from_response( # pylint: disable=protected-access |
| 265 | + response={ |
| 266 | + "candidates": [ |
| 267 | + { |
| 268 | + "content": { |
| 269 | + "parts": [{"text": "This is a mocked answer."}] |
| 270 | + } |
| 271 | + } |
| 272 | + ] |
| 273 | + }, |
| 274 | + kwargs={}, |
| 275 | + ) |
| 276 | + |
| 277 | + mock_genai_client.return_value.tunings.get.return_value = mock_tuning_job |
| 278 | + mock_genai_client.return_value.models.generate_content.return_value = mock_response |
| 279 | + |
| 280 | + tuning_with_checkpoints_textgen_with_txt.test_checkpoint("test-tuning-job") |
| 281 | + |
| 282 | + mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 283 | + mock_genai_client.return_value.tunings.get.assert_called_once() |
| 284 | + assert mock_genai_client.return_value.models.generate_content.call_args_list == [ |
| 285 | + call(model="test-endpoint-2", contents="Why is the sky blue?"), |
| 286 | + call(model="test-endpoint-1", contents="Why is the sky blue?"), |
| 287 | + call(model="test-endpoint-2", contents="Why is the sky blue?"), |
| 288 | + ] |
0 commit comments