Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions tests/smoke/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile

# Third Party
from datasets import load_dataset
from transformers import AutoModelForCausalLM
import huggingface_hub
import pytest
Expand Down Expand Up @@ -46,9 +47,12 @@
"rdzv_endpoint": "127.0.0.1:12345",
}

REFERENCE_TEST_MODEL = "instructlab/granite-7b-lab"
REFERENCE_TEST_MODEL = "ibm-granite/granite-3.3-2b-instruct"
RUNNER_CPUS_EXPECTED = 4

# Maximum number of examples to keep when randomly subsampling the processed dataset (speeds up training)
NUM_SAMPLES_TO_KEEP = 5000


@pytest.fixture(scope="module")
def custom_tmp_dir() -> Generator[pathlib.Path, None, None]:
Expand Down Expand Up @@ -190,7 +194,10 @@ def chat_template_in_repo_path() -> pathlib.Path:
def cached_training_data(
prepared_data_dir: pathlib.Path, cached_test_model: pathlib.Path
) -> pathlib.Path:
"""Renders test data in model template, tokenizes, and saves to fs"""
"""
Renders test data in model template, tokenizes, and saves to filesystem.
Subsamples at most NUM_SAMPLES_TO_KEEP examples to speed up tests.
"""

data_in_repo = data_in_repo_path()
chat_template = chat_template_in_repo_path()
Expand All @@ -206,7 +213,19 @@ def cached_training_data(

data_process.main(data_process_args)

return prepared_data_dir / "data.jsonl"
# Load the processed data and sample a subset
output_path = prepared_data_dir / "data.jsonl"
dataset = load_dataset("json", data_files=str(output_path), split="train")

# Randomly sample up to NUM_SAMPLES_TO_KEEP examples (fixed seed keeps the subset reproducible)
sampled_dataset = dataset.shuffle(seed=42).select(
range(min(NUM_SAMPLES_TO_KEEP, len(dataset)))
)

# Write the sampled data back to the same file
sampled_dataset.to_json(str(output_path), num_proc=RUNNER_CPUS_EXPECTED)

return output_path


@pytest.mark.slow
Expand Down
Loading