## Overview
A production-ready template for training LLMs with two strategies:
- FSDP for multi-GPU sharded training and sharded checkpoints
- Unsloth for efficient 4-bit fine-tuning on a single/multi-GPU node
It is designed for AWS SageMaker Spot with preemption-safe saves (SIGTERM handling) and automatic resume from the latest checkpoint. It also includes hardened data loading (HF Hub and S3 parquet) and robust checkpoint sync to S3.

For non-developers (quick start): use the provided commands without changing code. For developers: the code lives under `src/` in a standard Python package layout.
## Features

- FSDP strategy with sharded initialization and sharded checkpoints
- Unsloth strategy with 4-bit training and a fallback to Transformers
- Preemption-safe: SIGTERM/SIGINT triggers `emergency_stop()` with a final checkpoint (see the sketch after this list)
- Auto-resume: if `--resume` is not provided, the trainer uses the most recent checkpoint in `checkpoint.output_dir`
- Hardened dataloader: HF datasets or S3 parquet, batched mapping with prompt templating
- S3 checkpoint sync with retries and exponential backoff
- Dockerfile with CUDA PyTorch, `awscli`, and `s5cmd` preinstalled
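For illustration, the preemption-safe behavior amounts to routing SIGTERM/SIGINT into the trainer's `emergency_stop()`. The sketch below is not the template's actual code: `install_preemption_handlers` and the `trainer` argument are assumed names; only `emergency_stop()` comes from the feature list above.

```python
import signal


def install_preemption_handlers(trainer):
    """Route SIGTERM/SIGINT to the trainer's emergency_stop() so a final
    checkpoint is written before the Spot capacity is reclaimed."""

    def _handler(signum, frame):
        # The actual save/exit behavior lives in the trainer; this only forwards the signal.
        trainer.emergency_stop()

    signal.signal(signal.SIGTERM, _handler)
    signal.signal(signal.SIGINT, _handler)
```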
## Repository layout

- `src/fsdp_unsloth/`
  - `core/`: trainers, strategy selection, security checks
  - `common/`: logging, memory, checkpoint utils, and config adapter/schema
- `scripts/`
  - `train.py`: CLI entry (thin wrapper; you can also use the installed CLI)
  - `infer.py`: example inference script
  - `configs/`: example configs (FSDP/Unsloth + smoke)
- `.github/workflows/`: GitHub Actions for CI and pre-commit
## Setup

```bash
python -m venv .venv
. .venv/bin/activate
pip install uv
uv pip install -e ".[dev]"
pre-commit install
```

A template notebook is provided at `notebooks/secure_submit.ipynb`, which demonstrates:

- Building SageMaker guardrails via `scripts/core/security.py::build_sagemaker_guardrails()`
- Redacting secrets before logging configs
- Merging guardrails into a job request (example `boto3` call commented out)
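For a rough idea of the redaction step, a config can be masked before logging along the following lines; the function name and key list are illustrative, not the notebook's exact helper.

```python
import copy

# Illustrative key list; extend to match your config schema.
SENSITIVE_KEYS = {"hf_token", "wandb_api_key", "aws_secret_access_key"}


def redact_config(config: dict) -> dict:
    """Return a deep copy of the config with sensitive values masked."""
    redacted = copy.deepcopy(config)
    for key, value in redacted.items():
        if isinstance(value, dict):
            redacted[key] = redact_config(value)
        elif key.lower() in SENSITIVE_KEYS and value:
            redacted[key] = "***REDACTED***"
    return redacted
```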
## Environment

- Configure environment values using `.env.example` (copy it to `.env`).
- Optional: install `python-dotenv` (already in requirements) and load env in your scripts:

```python
from dotenv import load_dotenv

load_dotenv()
```
## Security preflight

- `src/fsdp_unsloth/core/strategy_selector.py` runs preflight checks (HF token format, S3/local path safety, W&B readiness) before trainer construction.
- Enable strict mode to fail fast:

```yaml
security:
  strict_preflight: true
```
## Docker image (recommended for SageMaker)
```bash
docker build -t unsloth-fsdp-training:latest .
```

## Configuration

- Base schema: `src/fsdp_unsloth/common/configs/base_config.yaml`
- Examples:
  - FSDP: `scripts/configs/fsdp/llama-7b.yaml`
  - Unsloth: `scripts/configs/unsloth/finance-alpaca.yaml`
  - Smoke tests: `scripts/configs/{fsdp,unsloth}/smoke.yaml`
Backend selection is explicit:

- Set `backend: fsdp` or `backend: unsloth` at the top of the config.
- CLI override is available via `--backend` (an alias of `--strategy`).

Key fields:

- `training.*` (batch sizes, lr, steps)
- `checkpoint.save_interval`, `checkpoint.output_dir`
- `logging.log_interval`, `logging.wandb_project`
- `model.name`, `model.max_length`, `model.load_in_4bit`, `model.hf_token`
- `fsdp.mixed_precision` and other sharding params
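For orientation, a config combining these fields might look roughly like the skeleton below. The values and the exact field names under `training.*` are placeholders; consult `base_config.yaml` and the example configs for the authoritative schema.

```yaml
backend: fsdp                     # or: unsloth

model:
  name: meta-llama/Llama-2-7b-hf  # illustrative model ID
  max_length: 2048
  load_in_4bit: false
  hf_token: null                  # required for gated models

training:
  # batch sizes, lr, steps: see base_config.yaml for the exact field names

checkpoint:
  save_interval: 200
  output_dir: outputs/

logging:
  log_interval: 10
  wandb_project: fsdp-unsloth

fsdp:
  mixed_precision: bf16           # plus other sharding params
```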
## Training

- Using the installed CLI (recommended):

  ```bash
  fsdp-train --config scripts/configs/fsdp/smoke.yaml --smoke
  fsdp-train --config scripts/configs/unsloth/smoke.yaml --backend unsloth --smoke
  ```

- Via the provided script wrapper (equivalent):

  ```bash
  python scripts/train.py --config scripts/configs/fsdp/llama-7b.yaml
  ```

- Multi-GPU (torchrun):

  ```bash
  make train-fsdp-mgpu NGPU=8
  make train-unsloth-mgpu NGPU=8
  ```

  Optional NCCL hints for multi-node networking are included (commented out) in the Makefile.
## Checkpoints

- FSDP saves sharded checkpoints into folders like `checkpoint_<step>/` under `checkpoint.output_dir`.
- Unsloth saves a single-file checkpoint `checkpoint_<step>.bin`.
- Auto-resume (when `--resume` is not provided): the trainer auto-detects the latest checkpoint in `checkpoint.output_dir` (see the sketch after this list).
- SageMaker: set `CheckpointConfig` (S3 URI); the trainer then syncs to `SM_CHECKPOINT_DIR` automatically.
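For illustration, latest-checkpoint discovery can be sketched as below, assuming the `checkpoint_<step>/` and `checkpoint_<step>.bin` naming described above; `find_latest_checkpoint` is an assumed name and the trainer's actual resume logic may differ in details.

```python
import re
from pathlib import Path


def find_latest_checkpoint(output_dir: str) -> Path | None:
    """Return the checkpoint_<step> directory or .bin file with the highest step."""
    pattern = re.compile(r"checkpoint_(\d+)(?:\.bin)?$")
    best_step, best_path = -1, None
    for path in Path(output_dir).glob("checkpoint_*"):
        match = pattern.fullmatch(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```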
## Data

- HF dataset: `data.name` is an HF dataset ID; streaming is supported.
- S3 parquet: `data.name` = `s3://bucket/path/file.parquet` (parquet only); uses `s3fs`.
- Prompt templating: define `data.prompt_template` using `{instruction}`, `{input}`, `{output}`, and `{eos_token}`.
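Conceptually, the template is a Python format string filled per record. The sketch below is illustrative only: `render_prompt` is an assumed helper name, the Alpaca-style template is just an example, and the real batched mapping lives in the dataloader.

```python
def render_prompt(template: str, record: dict, eos_token: str) -> str:
    """Fill a data.prompt_template string with one dataset record."""
    return template.format(
        instruction=record.get("instruction", ""),
        input=record.get("input", ""),
        output=record.get("output", ""),
        eos_token=eos_token,
    )


template = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n{output}{eos_token}"
)
print(render_prompt(template, {"instruction": "Summarize.", "input": "Spot notes", "output": "OK"}, "</s>"))
```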
## SageMaker Spot training

- Spot preemption triggers SIGTERM; the trainer catches it and performs an emergency checkpoint save.
- Recommended GPU instances:
  - FSDP: `p4d.24xlarge` (A100, 8x GPU) or `p5.48xlarge` (H100) for larger models
  - Unsloth: `g5.12xlarge` (A10G) or `p4d.24xlarge`, depending on model size
- Use `CheckpointConfig` for S3 checkpointing and enable Managed Spot Training. Ensure `MaxWaitTimeInSeconds > MaxRuntimeInSeconds` to allow for queueing.
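A hedged sketch of such a job using the SageMaker Python SDK follows; the image URI, role ARN, bucket, instance choice, and timeouts are placeholders, and this snippet is not part of the template itself.

```python
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="<account>.dkr.ecr.<region>.amazonaws.com/unsloth-fsdp-training:latest",
    role="arn:aws:iam::<account>:role/<sagemaker-execution-role>",
    instance_type="p4d.24xlarge",
    instance_count=1,
    # Managed Spot Training: max_wait must exceed max_run to allow for queueing.
    use_spot_instances=True,
    max_run=24 * 3600,
    max_wait=36 * 3600,
    # CheckpointConfig: checkpoints written locally are synced to this S3 URI,
    # so a preempted job can resume from the latest checkpoint.
    checkpoint_s3_uri="s3://<bucket>/checkpoints/",
    checkpoint_local_path="/opt/ml/checkpoints",
)
estimator.fit()
```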
## Smoke tests

Minimal runs to validate wiring and error handling:

```bash
make train-unsloth-smoke
make train-fsdp-smoke   # requires GPU
```

## Checkpoint conversion

Use `scripts/tools/convert_checkpoint.py` to convert between FSDP sharded directories and Unsloth single-file checkpoints.
- Convert FSDP shards to a single Unsloth file:

  ```bash
  python -m scripts.tools.convert_checkpoint \
    --source_path outputs/checkpoint_1000 \
    --target_path outputs/unsloth_1000.bin \
    --strategy fsdp --target_strategy unsloth
  ```

- Convert an Unsloth file to an FSDP shards directory:

  ```bash
  python -m scripts.tools.convert_checkpoint \
    --source_path outputs/unsloth_1000.bin \
    --target_path outputs/fsdp_1000 \
    --strategy unsloth --target_strategy fsdp
  ```

## Inference

Run inference with an optional checkpoint (single file or shard directory):
```bash
python scripts/infer.py \
  --config scripts/configs/unsloth/smoke.yaml \
  --prompt "Hello" \
  --checkpoint outputs/unsloth_1000.bin
```

## Contributing

Prereqs:
- Python 3.10+, CUDA drivers for GPU runs
- HF credentials (`HF_TOKEN`) if using gated models/datasets
- AWS credentials for S3 (optional, for S3 paths)
Workflow:

- Branch from `main` and implement changes.
- Run format/lint/tests:

  ```bash
  pre-commit run --all-files
  pytest -v
  ```
- Submit a PR with a concise description and test plan
## License

This project is licensed under the Apache License 2.0 (see `LICENSE`).
## Roadmap

- [docs] Add SageMaker job submission examples (Estimator config, Spot flags, CheckpointConfig)
- [fsdp] Add richer sharding options in the `fsdp` config (activation checkpointing policies, CPU offload)
- [resume] Write a `latest` pointer file after each save to speed up auto-resume discovery
- [inference] Validate and document `scripts/infer.py` for both strategies
- [tests] Add CPU-only unit tests and a small CI workflow for lint + schema checks
- [monitoring] Add optional CloudWatch/W&B guidance and Makefile targets for metrics sync
- [datasets] Add JSONL and multi-file S3 dataset examples