We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/)
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
on using `torchao.float8` in a distributed setting.

# Performance

A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark-based speedup estimate on NVIDIA H100:
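
For a rough read on where a particular shape lands, one option (this is an illustrative sketch, not the script used to produce the table) is to time a bf16 linear against its float8-converted copy directly; the shape, timing helper, and module below are assumptions:

```python
# Illustrative sketch only: compares the forward pass of a bf16 linear vs. its
# float8-converted copy for one (M, K, N). Assumes an H100 (or another
# float8-capable GPU) and the `convert_to_float8_training` API from this README.
import copy

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

M, K, N = 16384, 8192, 8192  # example shape; the speedup is shape-dependent

def benchmark_ms(fn, *args, n_iter=100):
    # simple CUDA-event timing with a short warmup
    for _ in range(10):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iter

m_bf16 = nn.Sequential(nn.Linear(K, N, bias=False)).to("cuda", torch.bfloat16)
m_fp8 = copy.deepcopy(m_bf16)
convert_to_float8_training(m_fp8)  # swaps nn.Linear -> Float8Linear in place

m_bf16 = torch.compile(m_bf16)
m_fp8 = torch.compile(m_fp8)

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
print(f"fwd speedup: {benchmark_ms(m_bf16, x) / benchmark_ms(m_fp8, x):.2f}x")
```
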
There are three observations we can make about the formula above, which compares the gemm time saved by going to float8 (the LHS) against the added float8 scaling and casting overhead (the RHS):

1. LHS > 0 for large shapes, with the gemm speedup approaching 2x as M, K, N increase
2. LHS < 0 for small shapes, on NVIDIA H100 + cuBLAS
3. RHS > 0 for all shapes, bounded by memory bandwidth, framework overhead and compiler limitations

For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude and the speedup depends on M, K, N as well as framework and compiler behavior. For large shapes, (1) leads to speedup > 1.
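
As a back-of-the-envelope illustration of this tradeoff, here is a rough roofline-style estimate; the throughput, bandwidth and overhead constants are assumptions rather than measured H100 numbers, and this simple model does not capture observation (2), where small float8 gemms can be slower than bf16 gemms on cuBLAS:

```python
# Rough roofline-style sketch: float8 wins when the gemm time it saves (LHS)
# exceeds the extra scaling/casting overhead it adds (RHS). All constants are
# illustrative assumptions.
def float8_fwd_speedup_estimate(
    M, K, N,
    bf16_peak_flops=0.99e15,   # assumed peak bf16 tensor-core throughput, FLOP/s
    fp8_peak_flops=1.98e15,    # assumed peak fp8 tensor-core throughput, FLOP/s
    mem_bw_bytes=3.35e12,      # assumed memory bandwidth, bytes/s
    fixed_overhead_s=20e-6,    # assumed framework / kernel-launch overhead
):
    gemm_flops = 2 * M * K * N
    bf16_gemm_s = gemm_flops / bf16_peak_flops
    fp8_gemm_s = gemm_flops / fp8_peak_flops
    # scaling/casting traffic: read the bf16 input and weight, write fp8 copies
    cast_bytes = 3 * (M * K + K * N)
    fp8_overhead_s = cast_bytes / mem_bw_bytes + fixed_overhead_s
    lhs = bf16_gemm_s - fp8_gemm_s   # gemm time saved by float8
    rhs = fp8_overhead_s             # time added for scaling and casting
    speedup = bf16_gemm_s / (fp8_gemm_s + fp8_overhead_s)
    return speedup, lhs > rhs

print(float8_fwd_speedup_estimate(16384, 8192, 8192))  # large shape: speedup > 1
print(float8_fwd_speedup_estimate(256, 256, 256))      # small shape: overhead dominates
```
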
## Scaling type vs speedup
Delayed scaling is theoretically faster than dynamic scaling because it requires less read/write traffic. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so its observed performance is close to that of dynamic scaling. As those torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant than dynamic scaling.
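
As a reference for switching between the two, below is a minimal sketch assuming the `Float8LinearConfig` / `CastConfig` / `ScalingType` configuration API used elsewhere in this README; double-check the names against the torchao version you are using:

```python
# Minimal sketch, assuming the Float8LinearConfig / CastConfig / ScalingType
# config API from this README; verify against the torchao release you use.
import torch
import torch.nn as nn
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)

# ScalingType.DYNAMIC (the default) recomputes scales from the current tensor;
# ScalingType.DELAYED derives scales from a history of amax values, trading
# extra bookkeeping state for less read/write traffic per iteration.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
convert_to_float8_training(m, config=config)
m = torch.compile(m)

# note: with delayed scaling, the amax/scale history also needs to be synced
# once per training step (see the delayed scaling example for your torchao
# version, e.g. sync_float8_amax_and_scale_history)
```
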
## torch.compile behavior vs speedup
There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As these limitations are resolved, we expect performance to improve.
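
One way to see these limitations concretely is to inspect the Triton code Inductor generates for a float8 model, for example to check how many separate kernels the scaling and casting ops land in. A minimal sketch; the module and shapes are arbitrary:

```python
# Sketch: dump the kernels Inductor generates for a compiled float8 model to
# see how the scaling/casting ops are fused; equivalent to running with
# TORCH_LOGS="output_code".
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

torch._logging.set_logs(output_code=True)

m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).to("cuda", torch.bfloat16)
convert_to_float8_training(m)
m = torch.compile(m)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()  # first call triggers compilation; generated code is logged
```
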