Commit a7483f2

Add a prototype of MX format training and inference (#264)
Summary:

The MX numerical formats are new low precision formats that were recently accepted into the OCP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR adds a reference native PyTorch implementation of training and inference primitives for MX accelerated matrix multiplications. Currently, we use a reference layout (scale and raw data stored separately) and an emulated matrix multiplication.

Test Plan:

```
// tests
pytest -s test/prototype/mx_formats/*

// benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 5b04ff0 commit a7483f2
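The summary above describes a reference layout in which each block's shared scale is stored separately from the raw element data, with scales constrained to powers of two. Below is a minimal illustrative sketch of that idea in plain NumPy; the names (`mx_quantize`, `mx_dequantize`, `BLOCK_SIZE`) are hypothetical and are not the torchao API, and the final rounding of elements to float8/float6/float4/int8 is deliberately omitted.

```python
import numpy as np

BLOCK_SIZE = 32  # the OCP MX spec shares one scale across blocks of 32 elements

def mx_quantize(x, elem_max=448.0):
    """Split a 1-D array into blocks and return (power-of-two scales, scaled data).

    elem_max=448.0 is the largest normal float8 e4m3 value; each block's
    shared scale maps the block's absolute max into that element range.
    """
    x = x.reshape(-1, BLOCK_SIZE)
    amax = np.abs(x).max(axis=1, keepdims=True)
    # The shared scale is constrained to a power of two (an E8M0-style
    # exponent). ceil keeps the scaled elements within elem_max.
    exp = np.ceil(np.log2(np.maximum(amax, 1e-30) / elem_max))
    scale = 2.0 ** exp
    # Only the scaling is emulated here; a real kernel would also round
    # the scaled data to the low precision element dtype.
    data = x / scale
    return scale, data

def mx_dequantize(scale, data):
    """Rebuild the original shape from per-block scales and scaled data."""
    return (data * scale).reshape(-1)

x = np.random.randn(64).astype(np.float32)
scale, data = mx_quantize(x)
x_hat = mx_dequantize(scale, data)
```

Because division and multiplication by a power of two are exact in floating point, this sketch round-trips losslessly; the accuracy loss in the real implementation comes from the element cast that is skipped here.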

File tree

16 files changed, +3181 / -0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ To learn more try out our APIs, you can check out API examples in
 3. Support for lower precision [dtypes](./torchao/dtypes) such as
 - [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code
 - [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py)
+- [MX](https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
 4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
 - [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
 - [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads

dev-requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@ transformers
 bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
 matplotlib
 pandas
+fire # QOL for commandline scripts
+tabulate # QOL for printing tables to stdout

 # Custom CUDA Extensions
 ninja
