diff --git a/tests/smoke/test_train.py b/tests/smoke/test_train.py index d7e81cc9..ab8057fb 100644 --- a/tests/smoke/test_train.py +++ b/tests/smoke/test_train.py @@ -7,6 +7,7 @@ import tempfile # Third Party +from datasets import load_dataset from transformers import AutoModelForCausalLM import huggingface_hub import pytest @@ -46,9 +47,12 @@ "rdzv_endpoint": "127.0.0.1:12345", } -REFERENCE_TEST_MODEL = "instructlab/granite-7b-lab" +REFERENCE_TEST_MODEL = "ibm-granite/granite-3.3-2b-instruct" RUNNER_CPUS_EXPECTED = 4 +# Maximum number of examples to randomly sample from the processed dataset for faster training +NUM_SAMPLES_TO_KEEP = 5000 + @pytest.fixture(scope="module") def custom_tmp_dir() -> Generator[pathlib.Path, None, None]: @@ -190,7 +194,10 @@ def chat_template_in_repo_path() -> pathlib.Path: def cached_training_data( prepared_data_dir: pathlib.Path, cached_test_model: pathlib.Path ) -> pathlib.Path: - """Renders test data in model template, tokenizes, and saves to fs""" + """ + Renders test data in model template, tokenizes, and saves to filesystem. + Subsamples at most NUM_SAMPLES_TO_KEEP examples to speed up tests. + """ data_in_repo = data_in_repo_path() chat_template = chat_template_in_repo_path() @@ -206,7 +213,19 @@ def cached_training_data( data_process.main(data_process_args) - return prepared_data_dir / "data.jsonl" + # Load the processed data and sample a subset + output_path = prepared_data_dir / "data.jsonl" + dataset = load_dataset("json", data_files=str(output_path), split="train") + + # Randomly sample at most NUM_SAMPLES_TO_KEEP examples + sampled_dataset = dataset.shuffle(seed=42).select( + range(min(NUM_SAMPLES_TO_KEEP, len(dataset))) + ) + + # Write the sampled data back to the same file + sampled_dataset.to_json(str(output_path), num_proc=RUNNER_CPUS_EXPECTED) + + return output_path @pytest.mark.slow