make dynamic scaling default in Float8Linear (#300)
Summary:
Pull Request resolved: #300
1. Makes dynamic scaling the default in Float8Linear for an easier migration of call sites that currently use Float8DynamicLinear. Fixes tests as needed.
2. Updates the README to reference Float8Linear for dynamic scaling.
Reviewed By: drisspg
Differential Revision: D59305790
fbshipit-source-id: 30d3813946239e0e958e0f7ed446082b578b0607
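The call-site migration this enables is essentially a one-line change: keep using `swap_linear_with_float8_linear`, but pass `Float8Linear` (which now defaults to dynamic scaling) instead of `Float8DynamicLinear`. A minimal sketch, assuming `swap_linear_with_float8_linear` takes the model and the target linear class (as in the README recipe in the diff below); `Model` here is a hypothetical module with plain `torch.nn.Linear` layers:

```python
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# hypothetical model containing torch.nn.Linear layers to be swapped
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

m = Model()

# previously: swap_linear_with_float8_linear(m, Float8DynamicLinear)
# with this change, Float8Linear defaults to dynamic scaling for x, w and dL_dY,
# so the same swap targets Float8Linear instead (signature assumed per the README)
swap_linear_with_float8_linear(m, Float8Linear)
```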
README.md: 20 additions & 11 deletions
@@ -27,21 +27,23 @@ pip install -e ".[dev]"
 
 # User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+
+This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`