
Commit 164435e

Merge branch 'pytorch:main' into quant/int4/wo/0

2 parents: 68eea61 + c5e7c18


42 files changed: +1952 / -453 lines

.github/workflows/1xH100_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ jobs:
  gpu-arch-version: ${{ matrix.gpu-arch-version }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
    python -m pip install --upgrade pip

.github/workflows/1xL4_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ jobs:
  gpu-arch-version: ${{ matrix.gpu-arch-version }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
    python -m pip install --upgrade pip

.github/workflows/4xH100_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ jobs:
  gpu-arch-version: ${{ matrix.gpu-arch-version }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
    python -m pip install --upgrade pip

.github/workflows/build_wheels_linux.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
  with-rocm: enable
  with-xpu: enable
  # Note: if free-threaded python is required add py3.13t here
- python-versions: '["3.9"]'
+ python-versions: '["3.10"]'

  build:
    needs: generate-matrix
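
Across these workflow updates the CI and wheel-build Python floor moves from 3.9 to 3.10 (the regression-test workflows below make the same change). As a hedged illustration only, not part of this commit, a runtime guard matching such a floor might look like:

```python
import sys

# Hypothetical guard for a Python >= 3.10 floor; the project's actual
# packaging metadata (python_requires/classifiers) is not shown in this diff.
if sys.version_info < (3, 10):
    raise RuntimeError(
        "Python 3.10+ is required, found "
        f"{sys.version_info.major}.{sys.version_info.minor}"
    )
```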

.github/workflows/doc_build.yml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ jobs:
  - name: Install dependencies
    run: |
      python -m pip install torch
+     python -m pip install setuptools==78.1.1 --force-reinstall
      python -m pip install -e .
      pip install -r dev-requirements.txt
      python -m pip install -r docs/requirements.txt
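
The doc build now pins setuptools to 78.1.1 with a force-reinstall before installing the package. A small hypothetical sanity check (not part of the commit) that the pin took effect could be:

```python
from importlib.metadata import version

# Hypothetical check: confirm the pinned setuptools version is the one installed.
expected = "78.1.1"
installed = version("setuptools")
assert installed == expected, f"expected setuptools {expected}, found {installed}"
```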

.github/workflows/regression_test.yml

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ jobs:
  gpu-arch-version: ${{ matrix.gpu-arch-version }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    python -m pip install --upgrade pip
    pip install ${{ matrix.torch-spec }}

@@ -105,7 +105,7 @@ jobs:
  gpu-arch-version: ${{ matrix.gpu-arch-version }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    echo "::group::Install newer objcopy that supports --set-section-alignment"
    dnf install -y gcc-toolset-10-binutils

.github/workflows/regression_test_rocm.yml

Lines changed: 4 additions & 4 deletions
@@ -22,10 +22,10 @@ jobs:
  include:
    - name: ROCM Nightly
      runs-on: linux.rocm.gpu.gfx942.2
-     torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
+     torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm7.0'
      gpu-arch-type: "rocm"
-     gpu-arch-version: "6.3"
-     docker-image: pytorch/manylinux2_28-builder:rocm6.3
+     gpu-arch-version: "7.0"
+     docker-image: pytorch/manylinux2_28-builder:rocm7.0

  permissions:
    id-token: write

@@ -40,7 +40,7 @@ jobs:
  docker-image: ${{ matrix.docker-image }}
  submodules: recursive
  script: |
-   conda create -n venv python=3.9 -y
+   conda create -n venv python=3.10 -y
    conda activate venv
    python -m pip install --upgrade pip
    pip install ${{ matrix.torch-spec }}
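
The ROCm job now pulls nightly wheels from the rocm7.0 index and builds in the rocm7.0 manylinux image. A minimal, hypothetical post-install check (not part of the commit) that a ROCm build of torch is actually active:

```python
import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA/CPU builds.
assert torch.version.hip is not None, "expected a ROCm (HIP) build of torch"
print("HIP version:", torch.version.hip)
print("GPU visible:", torch.cuda.is_available())
```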

README.md

Lines changed: 15 additions & 12 deletions
@@ -28,13 +28,13 @@
  - [May 25] QAT is now integrated into [Axolotl](https://github.com/axolotl-ai-cloud/axolotl) for fine-tuning ([docs](https://docs.axolotl.ai/docs/qat.html))!
  - [Apr 25] Float8 rowwise training yielded [1.34-1.43x training speedup](https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/) at 2k H100 GPU scale
  - [Apr 25] TorchAO is added as a [quantization backend to vLLM](https://docs.vllm.ai/en/latest/features/quantization/torchao.html) ([docs](https://docs.vllm.ai/en/latest/features/quantization/torchao.html))!
- - [Mar 25] Our [2:4 Sparsity paper](https://openreview.net/pdf?id=O5feVk7p6Y) was accepted to SLLM @ ICLR 2025!
- - [Jan 25] Our [integration with GemLite and SGLang](https://pytorch.org/blog/accelerating-llm-inference/) yielded 1.1-2x faster inference with int4 and float8 quantization across different batch sizes and tensor parallel sizes
- - [Jan 25] We added [1-8 bit ARM CPU kernels](https://pytorch.org/blog/hi-po-low-bit-operators/) for linear and embedding ops

  <details>
  <summary>Older news</summary>

+ - [Mar 25] Our [2:4 Sparsity paper](https://openreview.net/pdf?id=O5feVk7p6Y) was accepted to SLLM @ ICLR 2025!
+ - [Jan 25] Our [integration with GemLite and SGLang](https://pytorch.org/blog/accelerating-llm-inference/) yielded 1.1-2x faster inference with int4 and float8 quantization across different batch sizes and tensor parallel sizes
+ - [Jan 25] We added [1-8 bit ARM CPU kernels](https://pytorch.org/blog/hi-po-low-bit-operators/) for linear and embedding ops
  - [Nov 24] We achieved [1.43-1.51x faster pre-training](https://pytorch.org/blog/training-using-float8-fsdp2/) on Llama-3.1-70B and 405B using float8 training
  - [Oct 24] TorchAO is added as a quantization backend to HF Transformers!
  - [Sep 24] We officially launched TorchAO. Check out our blog [here](https://pytorch.org/blog/pytorch-native-architecture-optimization/)!

@@ -47,8 +47,7 @@

  ## 🌅 Overview

- TorchAO is a PyTorch-native model optimization framework leveraging quantization and sparsity to provide an end-to-end, training-to-serving workflow
- for AI models. TorchAO works out-of-the-box with `torch.compile()` and `FSDP2` across most HuggingFace PyTorch models. Key features include:
+ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works out-of-the-box with `torch.compile()` and `FSDP2` across most HuggingFace PyTorch models. Key features include:
  * Float8 [training](torchao/float8/README.md) and [inference](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) for speedups without compromising accuracy
  * [MX training and inference](torchao/prototype/mx_formats/README.md), provides MX tensor formats based on native PyTorch MX dtypes (prototype)
  * [Quantization-Aware Training (QAT)](torchao/quantization/qat/README.md) for mitigating quantization degradation

@@ -67,17 +66,17 @@ From the team that brought you the fast series:
  ## 🚀 Quick Start

  First, install TorchAO. We recommend installing the latest stable version:
- ```
+ ```bash
  pip install torchao
  ```

  Quantize your model weights to int4!
- ```
+ ```python
  from torchao.quantization import Int4WeightOnlyConfig, quantize_
  quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))
  ```
  Compared to a `torch.compiled` bf16 baseline, your quantized model should be significantly smaller and faster on a single A100 GPU:
- ```
+ ```bash
  int4 model size: 1.25 MB
  bfloat16 model size: 4.00 MB
  compression ratio: 3.2

@@ -86,13 +85,13 @@ bf16 mean time: 30.393 ms
  int4 mean time: 4.410 ms
  speedup: 6.9x
  ```
- For the full model setup and benchmark details, check out our [quick start guide](https://docs.pytorch.org/ao/stable/quick_start.html). Alternatively, try quantizing your favorite model using our [HuggingFace space](https://huggingface.co/spaces/pytorch/torchao-my-repo)!
+ See our [quick start guide](https://docs.pytorch.org/ao/stable/quick_start.html) for more details. Alternatively, try quantizing your favorite model using our [HuggingFace space](https://huggingface.co/spaces/pytorch/torchao-my-repo)!


  ## 🛠 Installation

  To install the latest stable version:
- ```
+ ```bash
  pip install torchao
  ```

@@ -196,7 +195,7 @@ quantize_(my_model, QATConfig(base_config, step="convert"))
  Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).


- ### Float8
+ ### Quantized training

  [torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With ``torch.compile`` on, current results show throughput speedups of up to **1.5x on up to 512 GPU / 405B parameter count scale** ([details](https://pytorch.org/blog/training-using-float8-fsdp2/)):

@@ -211,6 +210,8 @@ Our float8 training is integrated into [TorchTitan's pre-training flows](https:/
  * [Efficient Pre-training of Llama 3-like model architectures using torchtitan on Amazon SageMaker](https://aws.amazon.com/blogs/machine-learning/efficient-pre-training-of-llama-3-like-model-architectures-using-torchtitan-on-amazon-sagemaker/)
  * [Float8 in PyTorch](https://dev-discuss.pytorch.org/t/float8-in-pytorch-1-x/1815)

+ <details>
+ <summary>Other features (sparse training, memory efficient optimizers)</summary>

  ### Sparse Training

@@ -242,6 +243,8 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
  optim.load_state_dict(ckpt["optim"])
  ```

+ </details>
+
  <!--
  ## For Developers

@@ -258,7 +261,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu

  We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow

- 1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+ 1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
  2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
  3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
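
The quick start now pins the snippet to `Int4WeightOnlyConfig(group_size=32, version=1)`. A minimal end-to-end sketch of that call on a hypothetical toy model (the README's benchmarked model and timings are not reproduced here; a CUDA GPU and a bf16 model are assumed):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical toy model; the quick start applies the same call to any nn.Module.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False),
).to(dtype=torch.bfloat16, device="cuda")

# Quantize weights to int4, as in the updated README snippet.
quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))

# The quantized model is used like any other module; compile for best performance.
model = torch.compile(model)
out = model(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))
```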

benchmarks/float8/float8_roofline.py

Lines changed: 7 additions & 1 deletion
@@ -245,8 +245,12 @@ def run(
  bf16_gemm_time_sympy = get_gemm_time_sympy(
      M, K, N, torch.bfloat16, None, None, None
  )
+ lowp_input_dtype = torch.float8_e4m3fn
+ if mx_recipe_name == "mxfp4_cutlass":
+     lowp_input_dtype = torch.float4_e2m1fn_x2
+
  fp8_gemm_time_sympy = get_gemm_time_sympy(
-     M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None
+     M, K, N, lowp_input_dtype, float8_recipe_name, mx_recipe_name, None
  )
  print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
  print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)

@@ -304,6 +308,8 @@ def run(
  rb_fp8_gemm_ratio = -1

  if do_benchmarks:
+     assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+
      # TODO(future): make the bf16 gemm times exactly match the e2e
      # benchmarks, there is a slight deviation, probably related to gemm
      # operand memory formats/transpositions below not exactly matching
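
The roofline script now picks the low-precision GEMM input dtype from the MX recipe name instead of hard-coding float8 e4m3. A small sketch of that selection in isolation (the helper name is hypothetical; in the commit the logic is inline in `run()`):

```python
import torch

def select_lowp_input_dtype(mx_recipe_name):
    # Default to float8 e4m3; switch to the packed fp4 dtype for the mxfp4
    # cutlass recipe, mirroring the new branch in benchmarks/float8/float8_roofline.py.
    if mx_recipe_name == "mxfp4_cutlass":
        return torch.float4_e2m1fn_x2
    return torch.float8_e4m3fn

print(select_lowp_input_dtype(None))             # torch.float8_e4m3fn
print(select_lowp_input_dtype("mxfp4_cutlass"))  # torch.float4_e2m1fn_x2
```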

benchmarks/float8/training/llama4.sh renamed to benchmarks/float8/training/bench.sh

Lines changed: 12 additions & 9 deletions
@@ -7,17 +7,20 @@
  # This script can be used to launch a torchtitan float8 training run
  # with the given parameters,

- # script arguments
- LOCAL_BATCH_SIZE=${LOCAL_BATCH_SIZE:-1}
- STEPS=${STEPS:-100}
-
  # temporary log file which is deleted after performance data is parsed out and metrics are calculated.
- LOG_FILE="/tmp/float8_training_log.txt"
+ LOG_FILE="/tmp/torchtitan_logs.txt"

- # validate user has specified torchtitan root directory
+ # validate user has specified required args
  if [ -z "${TORCHTITAN_ROOT}" ]; then
-     echo "Error: TORCHTITAN environment variable is not set. Please set it before running this script."
-     echo "Usage: TORCHTITAN_ROOT=<directory> ./torchtitan_llama4.sh"
+     echo "Error: TORCHTITAN_ROOT environment variable is not set. Please set it before running this script."
+     echo "Usage: TORCHTITAN_ROOT=<directory> CONFIG_FILE=<model toml> ./moe.sh"
+     echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script."
+     exit 1
+ fi
+
+ if [ -z "${CONFIG_FILE}" ]; then
+     echo "Error: CONFIG_FILE environment variable is not set. Please set it before running this script."
+     echo "Usage: TORCHTITAN_ROOT=<directory> CONFIG_FILE=<model toml> ./moe.sh"
      echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script."
      exit 1
  fi

@@ -29,7 +32,7 @@ original_dir=$(pwd)
  cd ${TORCHTITAN_ROOT}

  # run the command with the specified arguments
- CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ${TORCHTITAN_ROOT}/run_train.sh ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE}
+ ${TORCHTITAN_ROOT}/run_train.sh ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE}

  # return to original working directory
  cd $original_dir
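
After the rename, the benchmark script requires both TORCHTITAN_ROOT and CONFIG_FILE in the environment (the hard-coded llama4 debug config is gone). A hypothetical Python wrapper (not part of the commit; paths are placeholders) showing how it expects to be invoked:

```python
import os
import subprocess

# Placeholder paths; point at a real torchtitan checkout and model .toml config.
env = dict(os.environ)
env["TORCHTITAN_ROOT"] = "/path/to/torchtitan"
env["CONFIG_FILE"] = "/path/to/torchtitan/train_configs/model.toml"
# EXTRA_ARGS (optional) is forwarded to the torchtitan training script.

subprocess.run(
    ["bash", "benchmarks/float8/training/bench.sh"],
    env=env,
    check=True,
)
```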
