Commit e2dad4a: Add performance section to float8 README.md (#794)
1 parent 287458c

torchao/float8/README.md: 51 additions, 0 deletions
@@ -122,6 +122,57 @@

We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
on using `torchao.float8` in a distributed setting.
# Performance
A common question about float8 training is "when is a float8 linear faster than bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark-based speedup estimate on NVIDIA H100:

<img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">
Example 1 (small shapes):
* forward input tensor size 1024x2048, linear weight size 2048x1024; M, K, N = 1024, 2048, 1024
* benchmark speedup is 0.80
* recommendation: leave this linear in bfloat16, the shapes are too small to benefit from float8 compute

Example 2 (large shapes):
* forward input tensor size 4096x8192, linear weight size 8192x16384; M, K, N = 4096, 8192, 16384
* benchmark speedup is 1.39
* recommendation: enable float8 for this linear to get a speedup
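As a quick illustration of the convention used in the examples above, M, K, N can be read directly off the input and weight shapes. The helper below is hypothetical (not part of torchao) and assumes the math-style convention used in this section: input is (M, K), weight is (K, N):

```python
# Hypothetical helper (not part of torchao): map a linear's shapes to the
# gemm dimensions M, K, N used in the speedup table above.
# Assumed convention: forward input is (M, K), weight is (K, N).

def linear_mkn(input_shape, weight_shape):
    M, K = input_shape
    K2, N = weight_shape
    assert K == K2, "inner dimensions must match"
    return M, K, N

# Example 1 from the text: small shapes, speedup < 1, keep bfloat16
print(linear_mkn((1024, 2048), (2048, 1024)))   # (1024, 2048, 1024)

# Example 2 from the text: large shapes, speedup > 1, enable float8
print(linear_mkn((4096, 8192), (8192, 16384)))  # (4096, 8192, 16384)
```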
140+
141+
To reproduce the raw data for table above, you can run the following script
142+
143+
```lang=shell
144+
python benchmarks/float8/float8_roofline.py your_output_filename.csv --gemm_time_strategy benchmarks --shape_gen_name sweep
145+
```
146+
147+
## Derivation

In a bf16 linear, assume all of the time is spent in gemms. In a float8 linear, additionally account for max_abs and casting overhead. We want to know when

```
bf16_gemm_time > fp8_gemm_time + fp8_overhead_time
```

Or, equivalently,

```
bf16_gemm_time - fp8_gemm_time > fp8_overhead_time
```
There are three observations we can make about the formula above:
1. LHS > 0 for large shapes, with the gemm speedup approaching 2x as M, K, N increase
2. LHS < 0 for small shapes, on NVIDIA H100 + cuBLAS
3. RHS > 0 for all shapes, bounded by memory bandwidth, framework overhead and compiler limitations

For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude, and the speedup depends on M, K, N and on framework and compiler behavior. For large shapes, (1) dominates and leads to speedup > 1.
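The decision rule from this derivation can be sketched in a few lines. This is a toy model with made-up illustrative timings, not measured H100 data; real estimates come from the roofline script mentioned earlier:

```python
# Toy sketch of the roofline decision rule from the derivation above.
# All timings are illustrative placeholders, NOT measured data; use
# benchmarks/float8/float8_roofline.py for real numbers.

def float8_speedup(bf16_gemm_time, fp8_gemm_time, fp8_overhead_time):
    """Estimated speedup of a float8 linear over a bf16 linear."""
    return bf16_gemm_time / (fp8_gemm_time + fp8_overhead_time)

def use_float8(bf16_gemm_time, fp8_gemm_time, fp8_overhead_time):
    """True when bf16_gemm_time - fp8_gemm_time > fp8_overhead_time."""
    return bf16_gemm_time - fp8_gemm_time > fp8_overhead_time

# Large shapes: the gemm savings exceed the casting overhead -> enable float8.
print(use_float8(bf16_gemm_time=2.0, fp8_gemm_time=1.1, fp8_overhead_time=0.3))   # True

# Small shapes: the overhead swamps the tiny gemm savings -> stay in bfloat16.
print(use_float8(bf16_gemm_time=0.1, fp8_gemm_time=0.09, fp8_overhead_time=0.05)) # False
```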
## Scaling type vs speedup

Delayed scaling is theoretically faster than dynamic scaling because of its reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching optimal behavior for delayed scaling, so its observed performance is close to that of dynamic scaling. As these torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant than dynamic scaling.
## torch.compile behavior vs speedup

There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As these limitations are resolved, we expect performance to improve.
# Testing

```bash
